Add support for binary arithmetic expressions with mixed scalar and vector operands Bug: tint:376 Change-Id: I2994ff7394efa903050b470a850b41628d5b775c Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52324 Commit-Queue: Antonio Maiorano <amaiorano@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 3be647b..baceffa 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc
@@ -2097,23 +2097,35 @@ SetType(expr, lhs_type); return true; } + + // Binary arithmetic expressions with mixed scalar and vector operands + if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_type)) { + if (expr->IsModulo()) { + if (rhs_type->is_integer_scalar()) { + SetType(expr, lhs_type); + return true; + } + } else if (rhs_type->is_numeric_scalar()) { + SetType(expr, lhs_type); + return true; + } + } + if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_type)) { + if (expr->IsModulo()) { + if (lhs_type->is_integer_scalar()) { + SetType(expr, rhs_type); + return true; + } + } else if (lhs_type->is_numeric_scalar()) { + SetType(expr, rhs_type); + return true; + } + } } - // Binary arithmetic expressions with mixed scalar, vector, and matrix - // operands + // Matrix arithmetic + // TODO(amaiorano): matrix-matrix addition and subtraction if (expr->IsMultiply()) { - // Multiplication of a vector and a scalar - if (lhs_type->Is<F32>() && rhs_vec_elem_type && - rhs_vec_elem_type->Is<F32>()) { - SetType(expr, rhs_type); - return true; - } - if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() && - rhs_type->Is<F32>()) { - SetType(expr, lhs_type); - return true; - } - auto* lhs_mat = lhs_type->As<Matrix>(); auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr; auto* rhs_mat = rhs_type->As<Matrix>();
diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index 7a088c7..06b6ce1 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc
@@ -1298,16 +1298,52 @@ Params{Op::kDivide, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>}, Params{Op::kModulo, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>}, - // Binary arithmetic expressions with mixed scalar, vector, and matrix - // operands - Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, - Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, + // Binary arithmetic expressions with mixed scalar and vector operands + Params{Op::kAdd, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>}, + Params{Op::kSubtract, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>}, + Params{Op::kMultiply, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>}, + Params{Op::kDivide, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>}, + Params{Op::kModulo, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>}, + Params{Op::kAdd, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>}, + Params{Op::kSubtract, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>}, + Params{Op::kMultiply, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>}, + Params{Op::kDivide, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>}, + Params{Op::kModulo, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>}, + + Params{Op::kAdd, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>}, + Params{Op::kSubtract, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>}, + Params{Op::kMultiply, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>}, + Params{Op::kDivide, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>}, + Params{Op::kModulo, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>}, + + Params{Op::kAdd, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>}, + Params{Op::kSubtract, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>}, + Params{Op::kMultiply, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>}, + Params{Op::kDivide, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>}, + Params{Op::kModulo, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>}, + + Params{Op::kAdd, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, + Params{Op::kSubtract, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, + Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, + Params{Op::kDivide, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, + // NOTE: no kModulo for ast_vec3<f32>, ast_f32 + // Params{Op::kModulo, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>}, + + Params{Op::kAdd, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, + Params{Op::kSubtract, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, + Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, + Params{Op::kDivide, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, + // NOTE: no kModulo for ast_f32, ast_vec3<f32> + // Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, + + // Matrix arithmetic Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>}, Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>}, Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>}, Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>}, + // TODO(amaiorano): add mat+mat and mat-mat Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 41d9f1b..6f96b21 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc
@@ -1719,6 +1719,31 @@ return result_id; } +uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) { + // Create a new vector to splat scalar into + auto splat_vector = result_op(); + auto* splat_vector_type = + builder_.create<sem::Pointer>(vec_type, ast::StorageClass::kFunction); + push_function_var( + {Operand::Int(GenerateTypeIfNeeded(splat_vector_type)), splat_vector, + Operand::Int(ConvertStorageClass(ast::StorageClass::kFunction)), + Operand::Int(GenerateConstantNullIfNeeded(vec_type))}); + + // Splat scalar into vector + auto splat_result = result_op(); + OperandList ops; + ops.push_back(Operand::Int(GenerateTypeIfNeeded(vec_type))); + ops.push_back(splat_result); + for (size_t i = 0; i < vec_type->As<sem::Vector>()->size(); ++i) { + ops.push_back(Operand::Int(scalar_id)); + } + if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) { + return 0; + } + + return splat_result.to_i(); +} + uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { // There is special logic for short circuiting operators. if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { @@ -1749,6 +1774,33 @@ // should have been rejected by validation. auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef(); auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef(); + + // For vector-scalar arithmetic operations, splat scalar into a vector. We + // skip this for multiply as we can use OpVectorTimesScalar. + const bool is_float_scalar_vector_multiply = + expr->IsMultiply() && + ((lhs_type->is_float_scalar() && rhs_type->is_float_vector()) || + (lhs_type->is_float_vector() && rhs_type->is_float_scalar())); + + if (expr->IsArithmetic() && !is_float_scalar_vector_multiply) { + if (lhs_type->Is<sem::Vector>() && rhs_type->is_numeric_scalar()) { + uint32_t splat_vector_id = GenerateSplat(rhs_id, lhs_type); + if (splat_vector_id == 0) { + return 0; + } + rhs_id = splat_vector_id; + rhs_type = lhs_type; + + } else if (lhs_type->is_numeric_scalar() && rhs_type->Is<sem::Vector>()) { + uint32_t splat_vector_id = GenerateSplat(lhs_id, rhs_type); + if (splat_vector_id == 0) { + return 0; + } + lhs_id = splat_vector_id; + lhs_type = rhs_type; + } + } + bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector(); bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 0774e7a..3daaa18 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h
@@ -473,6 +473,13 @@ /// @returns true if the vector was successfully generated bool GenerateVectorType(const sem::Vector* vec, const Operand& result); + /// Generates instructions to splat `scalar_id` into a vector of type + /// `vec_type` + /// @param scalar_id scalar to splat + /// @param vec_type type of vector + /// @returns id of the new vector + uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type); + /// Converts AST image format to SPIR-V and pushes an appropriate capability. /// @param format AST image format type /// @returns SPIR-V image format type
diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc index bdf19a2..e15d4ec 100644 --- a/src/writer/spirv/builder_binary_expression_test.cc +++ b/src/writer/spirv/builder_binary_expression_test.cc
@@ -863,6 +863,222 @@ )"); } +enum class Type { f32, i32, u32 }; +ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) { + switch (type) { + case Type::f32: + return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f); + case Type::i32: + return builder->vec3<ProgramBuilder::i32>(1, 1, 1); + case Type::u32: + return builder->vec3<ProgramBuilder::u32>(1u, 1u, 1u); + } + return nullptr; +} +ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) { + switch (type) { + case Type::f32: + return builder->Expr(1.f); + case Type::i32: + return builder->Expr(1); + case Type::u32: + return builder->Expr(1u); + } + return nullptr; +} +std::string OpTypeDecl(Type type) { + switch (type) { + case Type::f32: + return "OpTypeFloat 32"; + case Type::i32: + return "OpTypeInt 32 1"; + case Type::u32: + return "OpTypeInt 32 0"; + } + return {}; +} + +struct Param { + Type type; + ast::BinaryOp op; + std::string name; +}; + +using MixedBinaryArithTest = TestParamHelper<Param>; +TEST_P(MixedBinaryArithTest, VectorScalar) { + auto& param = GetParam(); + + ast::Expression* lhs = MakeVectorExpr(this, param.type); + ast::Expression* rhs = MakeScalarExpr(this, param.type); + std::string op_type_decl = OpTypeDecl(param.type); + + auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs); + + WrapInFunction(expr); + + spirv::Builder& b = Build(); + ASSERT_TRUE(b.Build()) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %3 "test_function" +OpExecutionMode %3 LocalSize 1 1 1 +OpName %3 "test_function" +%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%6 = )" + op_type_decl + R"( +%5 = OpTypeVector %6 3 +%7 = OpConstant %6 1 +%8 = OpConstantComposite %5 %7 %7 %7 +%11 = OpTypePointer Function %5 +%12 = OpConstantNull %5 +%3 = OpFunction %2 None %1 +%4 = OpLabel +%10 = OpVariable %11 Function %12 +%13 = OpCompositeConstruct %5 %7 %7 %7 +%9 = )" + param.name + R"( %5 %8 %13 +OpReturn +OpFunctionEnd +)"); + + Validate(b); +} +TEST_P(MixedBinaryArithTest, ScalarVector) { + auto& param = GetParam(); + + ast::Expression* lhs = MakeScalarExpr(this, param.type); + ast::Expression* rhs = MakeVectorExpr(this, param.type); + std::string op_type_decl = OpTypeDecl(param.type); + + auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs); + + WrapInFunction(expr); + + spirv::Builder& b = Build(); + ASSERT_TRUE(b.Build()) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %3 "test_function" +OpExecutionMode %3 LocalSize 1 1 1 +OpName %3 "test_function" +%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%5 = )" + op_type_decl + R"( +%6 = OpConstant %5 1 +%7 = OpTypeVector %5 3 +%8 = OpConstantComposite %7 %6 %6 %6 +%11 = OpTypePointer Function %7 +%12 = OpConstantNull %7 +%3 = OpFunction %2 None %1 +%4 = OpLabel +%10 = OpVariable %11 Function %12 +%13 = OpCompositeConstruct %7 %6 %6 %6 +%9 = )" + param.name + R"( %7 %13 %8 +OpReturn +OpFunctionEnd +)"); + + Validate(b); +} +INSTANTIATE_TEST_SUITE_P( + BuilderTest, + MixedBinaryArithTest, + testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"}, + Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"}, + // NOTE: Modulo not allowed on mixed float scalar-vector + // Param{Type::f32, ast::BinaryOp::kModulo, "OpFMod"}, + // NOTE: We test f32 multiplies separately as we emit + // OpVectorTimesScalar for this case + // Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"}, + Param{Type::f32, ast::BinaryOp::kSubtract, "OpFSub"}, + + Param{Type::i32, ast::BinaryOp::kAdd, "OpIAdd"}, + Param{Type::i32, ast::BinaryOp::kDivide, "OpSDiv"}, + Param{Type::i32, ast::BinaryOp::kModulo, "OpSMod"}, + Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"}, + Param{Type::i32, ast::BinaryOp::kSubtract, "OpISub"}, + + Param{Type::u32, ast::BinaryOp::kAdd, "OpIAdd"}, + Param{Type::u32, ast::BinaryOp::kDivide, "OpUDiv"}, + Param{Type::u32, ast::BinaryOp::kModulo, "OpUMod"}, + Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"}, + Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"})); + +using MixedBinaryArithMultiplyTest = TestParamHelper<Param>; +TEST_P(MixedBinaryArithMultiplyTest, VectorScalar) { + auto& param = GetParam(); + + ast::Expression* lhs = MakeVectorExpr(this, param.type); + ast::Expression* rhs = MakeScalarExpr(this, param.type); + std::string op_type_decl = OpTypeDecl(param.type); + + auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs); + + WrapInFunction(expr); + + spirv::Builder& b = Build(); + ASSERT_TRUE(b.Build()) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %3 "test_function" +OpExecutionMode %3 LocalSize 1 1 1 +OpName %3 "test_function" +%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%6 = )" + op_type_decl + R"( +%5 = OpTypeVector %6 3 +%7 = OpConstant %6 1 +%8 = OpConstantComposite %5 %7 %7 %7 +%3 = OpFunction %2 None %1 +%4 = OpLabel +%9 = OpVectorTimesScalar %5 %8 %7 +OpReturn +OpFunctionEnd +)"); + + Validate(b); +} +TEST_P(MixedBinaryArithMultiplyTest, ScalarVector) { + auto& param = GetParam(); + + ast::Expression* lhs = MakeScalarExpr(this, param.type); + ast::Expression* rhs = MakeVectorExpr(this, param.type); + std::string op_type_decl = OpTypeDecl(param.type); + + auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs); + + WrapInFunction(expr); + + spirv::Builder& b = Build(); + ASSERT_TRUE(b.Build()) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %3 "test_function" +OpExecutionMode %3 LocalSize 1 1 1 +OpName %3 "test_function" +%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%5 = )" + op_type_decl + R"( +%6 = OpConstant %5 1 +%7 = OpTypeVector %5 3 +%8 = OpConstantComposite %7 %6 %6 %6 +%3 = OpFunction %2 None %1 +%4 = OpLabel +%9 = OpVectorTimesScalar %7 %8 %6 +OpReturn +OpFunctionEnd +)"); + + Validate(b); +} +INSTANTIATE_TEST_SUITE_P(BuilderTest, + MixedBinaryArithMultiplyTest, + testing::Values(Param{ + Type::f32, ast::BinaryOp::kMultiply, "OpFMul"})); + } // namespace } // namespace spirv } // namespace writer
diff --git a/test/expressions/binary_expressions.wgsl b/test/expressions/binary_expressions.wgsl new file mode 100644 index 0000000..a58eece --- /dev/null +++ b/test/expressions/binary_expressions.wgsl
@@ -0,0 +1,70 @@ +fn vector_scalar_f32() { + var v : vec3<f32>; + var s : f32; + var r : vec3<f32>; + r = v + s; + r = v - s; + r = v * s; + r = v / s; + //r = v % s; +} + +fn vector_scalar_i32() { + var v : vec3<i32>; + var s : i32; + var r : vec3<i32>; + r = v + s; + r = v - s; + r = v * s; + r = v / s; + r = v % s; +} + +fn vector_scalar_u32() { + var v : vec3<u32>; + var s : u32; + var r : vec3<u32>; + r = v + s; + r = v - s; + r = v * s; + r = v / s; + r = v % s; +} + +fn scalar_vector_f32() { + var v : vec3<f32>; + var s : f32; + var r : vec3<f32>; + r = s + v; + r = s - v; + r = s * v; + r = s / v; + //r = s % v; +} + +fn scalar_vector_i32() { + var v : vec3<i32>; + var s : i32; + var r : vec3<i32>; + r = s + v; + r = s - v; + r = s * v; + r = s / v; + r = s % v; +} + +fn scalar_vector_u32() { + var v : vec3<u32>; + var s : u32; + var r : vec3<u32>; + r = s + v; + r = s - v; + r = s * v; + r = s / v; + r = s % v; +} + +[[stage(fragment)]] +fn main() -> [[location(0)]] vec4<f32> { + return vec4<f32>(0.0,0.0,0.0,0.0); +}
diff --git a/test/expressions/binary_expressions.wgsl.expected.hlsl b/test/expressions/binary_expressions.wgsl.expected.hlsl new file mode 100644 index 0000000..01f2d2b --- /dev/null +++ b/test/expressions/binary_expressions.wgsl.expected.hlsl
@@ -0,0 +1,73 @@ +struct tint_symbol { + float4 value : SV_Target0; +}; + +void vector_scalar_f32() { + float3 v = float3(0.0f, 0.0f, 0.0f); + float s = 0.0f; + float3 r = float3(0.0f, 0.0f, 0.0f); + r = (v + s); + r = (v - s); + r = (v * s); + r = (v / s); +} + +void vector_scalar_i32() { + int3 v = int3(0, 0, 0); + int s = 0; + int3 r = int3(0, 0, 0); + r = (v + s); + r = (v - s); + r = (v * s); + r = (v / s); + r = (v % s); +} + +void vector_scalar_u32() { + uint3 v = uint3(0u, 0u, 0u); + uint s = 0u; + uint3 r = uint3(0u, 0u, 0u); + r = (v + s); + r = (v - s); + r = (v * s); + r = (v / s); + r = (v % s); +} + +void scalar_vector_f32() { + float3 v = float3(0.0f, 0.0f, 0.0f); + float s = 0.0f; + float3 r = float3(0.0f, 0.0f, 0.0f); + r = (s + v); + r = (s - v); + r = (s * v); + r = (s / v); +} + +void scalar_vector_i32() { + int3 v = int3(0, 0, 0); + int s = 0; + int3 r = int3(0, 0, 0); + r = (s + v); + r = (s - v); + r = (s * v); + r = (s / v); + r = (s % v); +} + +void scalar_vector_u32() { + uint3 v = uint3(0u, 0u, 0u); + uint s = 0u; + uint3 r = uint3(0u, 0u, 0u); + r = (s + v); + r = (s - v); + r = (s * v); + r = (s / v); + r = (s % v); +} + +tint_symbol main() { + const tint_symbol tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f)}; + return tint_symbol_1; +} +
diff --git a/test/expressions/binary_expressions.wgsl.expected.msl b/test/expressions/binary_expressions.wgsl.expected.msl new file mode 100644 index 0000000..5630940 --- /dev/null +++ b/test/expressions/binary_expressions.wgsl.expected.msl
@@ -0,0 +1,75 @@ +#include <metal_stdlib> + +using namespace metal; +struct tint_symbol_1 { + float4 value [[color(0)]]; +}; + +void vector_scalar_f32() { + float3 v = 0.0f; + float s = 0.0f; + float3 r = 0.0f; + r = (v + s); + r = (v - s); + r = (v * s); + r = (v / s); +} + +void vector_scalar_i32() { + int3 v = 0; + int s = 0; + int3 r = 0; + r = (v + s); + r = (v - s); + r = (v * s); + r = (v / s); + r = (v % s); +} + +void vector_scalar_u32() { + uint3 v = 0u; + uint s = 0u; + uint3 r = 0u; + r = (v + s); + r = (v - s); + r = (v * s); + r = (v / s); + r = (v % s); +} + +void scalar_vector_f32() { + float3 v = 0.0f; + float s = 0.0f; + float3 r = 0.0f; + r = (s + v); + r = (s - v); + r = (s * v); + r = (s / v); +} + +void scalar_vector_i32() { + int3 v = 0; + int s = 0; + int3 r = 0; + r = (s + v); + r = (s - v); + r = (s * v); + r = (s / v); + r = (s % v); +} + +void scalar_vector_u32() { + uint3 v = 0u; + uint s = 0u; + uint3 r = 0u; + r = (s + v); + r = (s - v); + r = (s * v); + r = (s / v); + r = (s % v); +} + +fragment tint_symbol_1 tint_symbol() { + return {float4(0.0f, 0.0f, 0.0f, 0.0f)}; +} +
diff --git a/test/expressions/binary_expressions.wgsl.expected.spvasm b/test/expressions/binary_expressions.wgsl.expected.spvasm new file mode 100644 index 0000000..445bb49 --- /dev/null +++ b/test/expressions/binary_expressions.wgsl.expected.spvasm
@@ -0,0 +1,282 @@ +; SPIR-V +; Version: 1.3 +; Generator: Google Tint Compiler; 0 +; Bound: 200 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %tint_symbol_1 + OpExecutionMode %main OriginUpperLeft + OpName %tint_symbol_1 "tint_symbol_1" + OpName %vector_scalar_f32 "vector_scalar_f32" + OpName %v "v" + OpName %s "s" + OpName %r "r" + OpName %vector_scalar_i32 "vector_scalar_i32" + OpName %v_0 "v" + OpName %s_0 "s" + OpName %r_0 "r" + OpName %vector_scalar_u32 "vector_scalar_u32" + OpName %v_1 "v" + OpName %s_1 "s" + OpName %r_1 "r" + OpName %scalar_vector_f32 "scalar_vector_f32" + OpName %v_2 "v" + OpName %s_2 "s" + OpName %r_2 "r" + OpName %scalar_vector_i32 "scalar_vector_i32" + OpName %v_3 "v" + OpName %s_3 "s" + OpName %r_3 "r" + OpName %scalar_vector_u32 "scalar_vector_u32" + OpName %v_4 "v" + OpName %s_4 "s" + OpName %r_4 "r" + OpName %tint_symbol_2 "tint_symbol_2" + OpName %tint_symbol "tint_symbol" + OpName %main "main" + OpDecorate %tint_symbol_1 Location 0 + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 +%_ptr_Output_v4float = OpTypePointer Output %v4float + %5 = OpConstantNull %v4float +%tint_symbol_1 = OpVariable %_ptr_Output_v4float Output %5 + %void = OpTypeVoid + %6 = OpTypeFunction %void + %v3float = OpTypeVector %float 3 +%_ptr_Function_v3float = OpTypePointer Function %v3float + %13 = OpConstantNull %v3float +%_ptr_Function_float = OpTypePointer Function %float + %16 = OpConstantNull %float + %int = OpTypeInt 32 1 + %v3int = OpTypeVector %int 3 +%_ptr_Function_v3int = OpTypePointer Function %v3int + %42 = OpConstantNull %v3int +%_ptr_Function_int = OpTypePointer Function %int + %45 = OpConstantNull %int + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 +%_ptr_Function_v3uint = OpTypePointer Function %v3uint + %78 = OpConstantNull %v3uint +%_ptr_Function_uint = OpTypePointer Function %uint + %81 = OpConstantNull %uint + %191 = OpTypeFunction %void %v4float + %float_0 = OpConstant %float 0 + %199 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 +%vector_scalar_f32 = OpFunction %void None %6 + %9 = OpLabel + %v = OpVariable %_ptr_Function_v3float Function %13 + %s = OpVariable %_ptr_Function_float Function %16 + %r = OpVariable %_ptr_Function_v3float Function %13 + %21 = OpVariable %_ptr_Function_v3float Function %13 + %26 = OpVariable %_ptr_Function_v3float Function %13 + %34 = OpVariable %_ptr_Function_v3float Function %13 + %18 = OpLoad %v3float %v + %19 = OpLoad %float %s + %22 = OpCompositeConstruct %v3float %19 %19 %19 + %20 = OpFAdd %v3float %18 %22 + OpStore %r %20 + %23 = OpLoad %v3float %v + %24 = OpLoad %float %s + %27 = OpCompositeConstruct %v3float %24 %24 %24 + %25 = OpFSub %v3float %23 %27 + OpStore %r %25 + %28 = OpLoad %v3float %v + %29 = OpLoad %float %s + %30 = OpVectorTimesScalar %v3float %28 %29 + OpStore %r %30 + %31 = OpLoad %v3float %v + %32 = OpLoad %float %s + %35 = OpCompositeConstruct %v3float %32 %32 %32 + %33 = OpFDiv %v3float %31 %35 + OpStore %r %33 + OpReturn + OpFunctionEnd +%vector_scalar_i32 = OpFunction %void None %6 + %37 = OpLabel + %v_0 = OpVariable %_ptr_Function_v3int Function %42 + %s_0 = OpVariable %_ptr_Function_int Function %45 + %r_0 = OpVariable %_ptr_Function_v3int Function %42 + %50 = OpVariable %_ptr_Function_v3int Function %42 + %55 = OpVariable %_ptr_Function_v3int Function %42 + %60 = OpVariable %_ptr_Function_v3int Function %42 + %65 = OpVariable %_ptr_Function_v3int Function %42 + %70 = OpVariable %_ptr_Function_v3int Function %42 + %47 = OpLoad %v3int %v_0 + %48 = OpLoad %int %s_0 + %51 = OpCompositeConstruct %v3int %48 %48 %48 + %49 = OpIAdd %v3int %47 %51 + OpStore %r_0 %49 + %52 = OpLoad %v3int %v_0 + %53 = OpLoad %int %s_0 + %56 = OpCompositeConstruct %v3int %53 %53 %53 + %54 = OpISub %v3int %52 %56 + OpStore %r_0 %54 + %57 = OpLoad %v3int %v_0 + %58 = OpLoad %int %s_0 + %61 = OpCompositeConstruct %v3int %58 %58 %58 + %59 = OpIMul %v3int %57 %61 + OpStore %r_0 %59 + %62 = OpLoad %v3int %v_0 + %63 = OpLoad %int %s_0 + %66 = OpCompositeConstruct %v3int %63 %63 %63 + %64 = OpSDiv %v3int %62 %66 + OpStore %r_0 %64 + %67 = OpLoad %v3int %v_0 + %68 = OpLoad %int %s_0 + %71 = OpCompositeConstruct %v3int %68 %68 %68 + %69 = OpSMod %v3int %67 %71 + OpStore %r_0 %69 + OpReturn + OpFunctionEnd +%vector_scalar_u32 = OpFunction %void None %6 + %73 = OpLabel + %v_1 = OpVariable %_ptr_Function_v3uint Function %78 + %s_1 = OpVariable %_ptr_Function_uint Function %81 + %r_1 = OpVariable %_ptr_Function_v3uint Function %78 + %86 = OpVariable %_ptr_Function_v3uint Function %78 + %91 = OpVariable %_ptr_Function_v3uint Function %78 + %96 = OpVariable %_ptr_Function_v3uint Function %78 + %101 = OpVariable %_ptr_Function_v3uint Function %78 + %106 = OpVariable %_ptr_Function_v3uint Function %78 + %83 = OpLoad %v3uint %v_1 + %84 = OpLoad %uint %s_1 + %87 = OpCompositeConstruct %v3uint %84 %84 %84 + %85 = OpIAdd %v3uint %83 %87 + OpStore %r_1 %85 + %88 = OpLoad %v3uint %v_1 + %89 = OpLoad %uint %s_1 + %92 = OpCompositeConstruct %v3uint %89 %89 %89 + %90 = OpISub %v3uint %88 %92 + OpStore %r_1 %90 + %93 = OpLoad %v3uint %v_1 + %94 = OpLoad %uint %s_1 + %97 = OpCompositeConstruct %v3uint %94 %94 %94 + %95 = OpIMul %v3uint %93 %97 + OpStore %r_1 %95 + %98 = OpLoad %v3uint %v_1 + %99 = OpLoad %uint %s_1 + %102 = OpCompositeConstruct %v3uint %99 %99 %99 + %100 = OpUDiv %v3uint %98 %102 + OpStore %r_1 %100 + %103 = OpLoad %v3uint %v_1 + %104 = OpLoad %uint %s_1 + %107 = OpCompositeConstruct %v3uint %104 %104 %104 + %105 = OpUMod %v3uint %103 %107 + OpStore %r_1 %105 + OpReturn + OpFunctionEnd +%scalar_vector_f32 = OpFunction %void None %6 + %109 = OpLabel + %v_2 = OpVariable %_ptr_Function_v3float Function %13 + %s_2 = OpVariable %_ptr_Function_float Function %16 + %r_2 = OpVariable %_ptr_Function_v3float Function %13 + %116 = OpVariable %_ptr_Function_v3float Function %13 + %121 = OpVariable %_ptr_Function_v3float Function %13 + %129 = OpVariable %_ptr_Function_v3float Function %13 + %113 = OpLoad %float %s_2 + %114 = OpLoad %v3float %v_2 + %117 = OpCompositeConstruct %v3float %113 %113 %113 + %115 = OpFAdd %v3float %117 %114 + OpStore %r_2 %115 + %118 = OpLoad %float %s_2 + %119 = OpLoad %v3float %v_2 + %122 = OpCompositeConstruct %v3float %118 %118 %118 + %120 = OpFSub %v3float %122 %119 + OpStore %r_2 %120 + %123 = OpLoad %float %s_2 + %124 = OpLoad %v3float %v_2 + %125 = OpVectorTimesScalar %v3float %124 %123 + OpStore %r_2 %125 + %126 = OpLoad %float %s_2 + %127 = OpLoad %v3float %v_2 + %130 = OpCompositeConstruct %v3float %126 %126 %126 + %128 = OpFDiv %v3float %130 %127 + OpStore %r_2 %128 + OpReturn + OpFunctionEnd +%scalar_vector_i32 = OpFunction %void None %6 + %132 = OpLabel + %v_3 = OpVariable %_ptr_Function_v3int Function %42 + %s_3 = OpVariable %_ptr_Function_int Function %45 + %r_3 = OpVariable %_ptr_Function_v3int Function %42 + %139 = OpVariable %_ptr_Function_v3int Function %42 + %144 = OpVariable %_ptr_Function_v3int Function %42 + %149 = OpVariable %_ptr_Function_v3int Function %42 + %154 = OpVariable %_ptr_Function_v3int Function %42 + %159 = OpVariable %_ptr_Function_v3int Function %42 + %136 = OpLoad %int %s_3 + %137 = OpLoad %v3int %v_3 + %140 = OpCompositeConstruct %v3int %136 %136 %136 + %138 = OpIAdd %v3int %140 %137 + OpStore %r_3 %138 + %141 = OpLoad %int %s_3 + %142 = OpLoad %v3int %v_3 + %145 = OpCompositeConstruct %v3int %141 %141 %141 + %143 = OpISub %v3int %145 %142 + OpStore %r_3 %143 + %146 = OpLoad %int %s_3 + %147 = OpLoad %v3int %v_3 + %150 = OpCompositeConstruct %v3int %146 %146 %146 + %148 = OpIMul %v3int %150 %147 + OpStore %r_3 %148 + %151 = OpLoad %int %s_3 + %152 = OpLoad %v3int %v_3 + %155 = OpCompositeConstruct %v3int %151 %151 %151 + %153 = OpSDiv %v3int %155 %152 + OpStore %r_3 %153 + %156 = OpLoad %int %s_3 + %157 = OpLoad %v3int %v_3 + %160 = OpCompositeConstruct %v3int %156 %156 %156 + %158 = OpSMod %v3int %160 %157 + OpStore %r_3 %158 + OpReturn + OpFunctionEnd +%scalar_vector_u32 = OpFunction %void None %6 + %162 = OpLabel + %v_4 = OpVariable %_ptr_Function_v3uint Function %78 + %s_4 = OpVariable %_ptr_Function_uint Function %81 + %r_4 = OpVariable %_ptr_Function_v3uint Function %78 + %169 = OpVariable %_ptr_Function_v3uint Function %78 + %174 = OpVariable %_ptr_Function_v3uint Function %78 + %179 = OpVariable %_ptr_Function_v3uint Function %78 + %184 = OpVariable %_ptr_Function_v3uint Function %78 + %189 = OpVariable %_ptr_Function_v3uint Function %78 + %166 = OpLoad %uint %s_4 + %167 = OpLoad %v3uint %v_4 + %170 = OpCompositeConstruct %v3uint %166 %166 %166 + %168 = OpIAdd %v3uint %170 %167 + OpStore %r_4 %168 + %171 = OpLoad %uint %s_4 + %172 = OpLoad %v3uint %v_4 + %175 = OpCompositeConstruct %v3uint %171 %171 %171 + %173 = OpISub %v3uint %175 %172 + OpStore %r_4 %173 + %176 = OpLoad %uint %s_4 + %177 = OpLoad %v3uint %v_4 + %180 = OpCompositeConstruct %v3uint %176 %176 %176 + %178 = OpIMul %v3uint %180 %177 + OpStore %r_4 %178 + %181 = OpLoad %uint %s_4 + %182 = OpLoad %v3uint %v_4 + %185 = OpCompositeConstruct %v3uint %181 %181 %181 + %183 = OpUDiv %v3uint %185 %182 + OpStore %r_4 %183 + %186 = OpLoad %uint %s_4 + %187 = OpLoad %v3uint %v_4 + %190 = OpCompositeConstruct %v3uint %186 %186 %186 + %188 = OpUMod %v3uint %190 %187 + OpStore %r_4 %188 + OpReturn + OpFunctionEnd +%tint_symbol_2 = OpFunction %void None %191 +%tint_symbol = OpFunctionParameter %v4float + %194 = OpLabel + OpStore %tint_symbol_1 %tint_symbol + OpReturn + OpFunctionEnd + %main = OpFunction %void None %6 + %196 = OpLabel + %197 = OpFunctionCall %void %tint_symbol_2 %199 + OpReturn + OpFunctionEnd
diff --git a/test/expressions/binary_expressions.wgsl.expected.wgsl b/test/expressions/binary_expressions.wgsl.expected.wgsl new file mode 100644 index 0000000..ec7e7d5 --- /dev/null +++ b/test/expressions/binary_expressions.wgsl.expected.wgsl
@@ -0,0 +1,68 @@ +fn vector_scalar_f32() { + var v : vec3<f32>; + var s : f32; + var r : vec3<f32>; + r = (v + s); + r = (v - s); + r = (v * s); + r = (v / s); +} + +fn vector_scalar_i32() { + var v : vec3<i32>; + var s : i32; + var r : vec3<i32>; + r = (v + s); + r = (v - s); + r = (v * s); + r = (v / s); + r = (v % s); +} + +fn vector_scalar_u32() { + var v : vec3<u32>; + var s : u32; + var r : vec3<u32>; + r = (v + s); + r = (v - s); + r = (v * s); + r = (v / s); + r = (v % s); +} + +fn scalar_vector_f32() { + var v : vec3<f32>; + var s : f32; + var r : vec3<f32>; + r = (s + v); + r = (s - v); + r = (s * v); + r = (s / v); +} + +fn scalar_vector_i32() { + var v : vec3<i32>; + var s : i32; + var r : vec3<i32>; + r = (s + v); + r = (s - v); + r = (s * v); + r = (s / v); + r = (s % v); +} + +fn scalar_vector_u32() { + var v : vec3<u32>; + var s : u32; + var r : vec3<u32>; + r = (s + v); + r = (s - v); + r = (s * v); + r = (s / v); + r = (s % v); +} + +[[stage(fragment)]] +fn main() -> [[location(0)]] vec4<f32> { + return vec4<f32>(0.0, 0.0, 0.0, 0.0); +}