spirv-reader: Fix signedness for extended instructions
SClamp is the only implemented instruction that is affected, so far
Bug: tint:405
Change-Id: I21c1cdd3e70fc3a64046f0473569ba906048cd37
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35240
Auto-Submit: David Neto <dneto@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 956aec6..61d35ec 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -2862,7 +2862,7 @@
const spvtools::opt::Instruction& inst,
uint32_t operand_index) {
auto expr = this->MakeExpression(inst.GetSingleWordInOperand(operand_index));
- return parser_impl_.RectifyOperandSignedness(inst.opcode(), std::move(expr));
+ return parser_impl_.RectifyOperandSignedness(inst, std::move(expr));
}
TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
@@ -2883,7 +2883,7 @@
auto* binary_expr =
create<ast::BinaryExpression>(binary_op, arg0.expr, arg1.expr);
TypedExpression result{ast_type, binary_expr};
- return parser_impl_.RectifyForcedResultType(result, opcode, arg0.type);
+ return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
}
auto unary_op = ast::UnaryOp::kNegation;
@@ -2891,7 +2891,7 @@
auto arg0 = MakeOperand(inst, 0);
auto* unary_expr = create<ast::UnaryOpExpression>(unary_op, arg0.expr);
TypedExpression result{ast_type, unary_expr};
- return parser_impl_.RectifyForcedResultType(result, opcode, arg0.type);
+ return parser_impl_.RectifyForcedResultType(result, inst, arg0.type);
}
const char* unary_builtin_name = GetUnaryBuiltInFunctionName(opcode);
@@ -3002,13 +3002,20 @@
auto* func = create<ast::IdentifierExpression>(name);
ast::ExpressionList operands;
+ ast::type::Type* first_operand_type = nullptr;
// All parameters to GLSL.std.450 extended instructions are IDs.
for (uint32_t iarg = 2; iarg < inst.NumInOperands(); ++iarg) {
- operands.emplace_back(MakeOperand(inst, iarg).expr);
+ TypedExpression operand = MakeOperand(inst, iarg);
+ if (first_operand_type == nullptr) {
+ first_operand_type = operand.type;
+ }
+ operands.emplace_back(operand.expr);
}
auto* ast_type = parser_impl_.ConvertType(inst.type_id());
auto* call = create<ast::CallExpression>(func, std::move(operands));
- return {ast_type, call};
+ TypedExpression call_expr{ast_type, call};
+ return parser_impl_.RectifyForcedResultType(call_expr, inst,
+ first_operand_type);
}
ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) {
diff --git a/src/reader/spirv/function_glsl_std_450_test.cc b/src/reader/spirv/function_glsl_std_450_test.cc
index 10c16da..203d357 100644
--- a/src/reader/spirv/function_glsl_std_450_test.cc
+++ b/src/reader/spirv/function_glsl_std_450_test.cc
@@ -826,6 +826,66 @@
SpvParserTest_GlslStd450_Uinting_UintingUintingUinting,
::testing::Values(GlslStd450Case{"UClamp", "clamp"}));
+TEST_F(SpvParserTest, RectifyOperandsAndResult_GLSLstd450SClamp) {
+ const auto assembly = Preamble() + R"(
+ %1 = OpExtInst %uint %glsl SClamp %u1 %i2 %u3
+ %2 = OpExtInst %v2uint %glsl SClamp %v2u1 %v2i2 %v2u3
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ auto body = ToString(fe.ast_body());
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_1
+ none
+ __u32
+ {
+ Bitcast[not set]<__u32>{
+ Call[not set]{
+ Identifier[not set]{clamp}
+ (
+ Bitcast[not set]<__i32>{
+ Identifier[not set]{u1}
+ }
+ Identifier[not set]{i2}
+ Bitcast[not set]<__i32>{
+ Identifier[not set]{u3}
+ }
+ )
+ }
+ }
+ }
+ })"))
+ << body;
+ EXPECT_THAT(body, HasSubstr(R"(
+ VariableConst{
+ x_2
+ none
+ __vec_2__u32
+ {
+ Bitcast[not set]<__vec_2__u32>{
+ Call[not set]{
+ Identifier[not set]{clamp}
+ (
+ Bitcast[not set]<__vec_2__i32>{
+ Identifier[not set]{v2u1}
+ }
+ Identifier[not set]{v2i2}
+ Bitcast[not set]<__vec_2__i32>{
+ Identifier[not set]{v2u3}
+ }
+ )
+ }
+ }
+ }
+ })"))
+ << body;
+}
+
} // namespace
} // namespace spirv
} // namespace reader
diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc
index 12908b4..3543a77 100644
--- a/src/reader/spirv/parser_impl.cc
+++ b/src/reader/spirv/parser_impl.cc
@@ -159,6 +159,23 @@
return false;
}
+// Returns true if the GLSL extended instruction expects operands to be signed.
+// @param extended_opcode GLSL.std.450 opcode
+// @returns true if all operands must be signed integral type
+bool AssumesSignedOperands(GLSLstd450 extended_opcode) {
+ switch (extended_opcode) {
+ case GLSLstd450SAbs:
+ case GLSLstd450SSign:
+ case GLSLstd450SMin:
+ case GLSLstd450SMax:
+ case GLSLstd450SClamp:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
// Returns true if the opcode operates as if its operands are unsigned integral.
bool AssumesUnsignedOperands(SpvOp opcode) {
switch (opcode) {
@@ -176,6 +193,22 @@
return false;
}
+// Returns true if the GLSL extended instruction expects operands to be
+// unsigned.
+// @param extended_opcode GLSL.std.450 opcode
+// @returns true if all operands must be unsigned integral type
+bool AssumesUnsignedOperands(GLSLstd450 extended_opcode) {
+ switch (extended_opcode) {
+ case GLSLstd450UMin:
+ case GLSLstd450UMax:
+ case GLSLstd450UClamp:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
// Returns true if the operation is binary, and the WGSL operation requires
// the signedness of the result to match the signedness of the first operand.
bool AssumesResultSignednessMatchesBinaryFirstOperand(SpvOp opcode) {
@@ -190,6 +223,29 @@
return false;
}
+// Returns true if the extended instruction requires the signedness of the
+// result to match the signedness of the first operand to the operation.
+// @param extended_opcode GLSL.std.450 opcode
+// @returns true if the result type must match the first operand type.
+bool AssumesResultSignednessMatchesFirstOperand(GLSLstd450 extended_opcode) {
+ switch (extended_opcode) {
+ case GLSLstd450SAbs:
+ case GLSLstd450SSign:
+ case GLSLstd450SMin:
+ case GLSLstd450SMax:
+ case GLSLstd450SClamp:
+ case GLSLstd450UMin:
+ case GLSLstd450UMax:
+ case GLSLstd450UClamp:
+ // TODO(dneto): FindSMsb?
+ // TODO(dneto): FindUMsb?
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
} // namespace
ParserImpl::ParserImpl(const std::vector<uint32_t>& spv_binary)
@@ -561,6 +617,12 @@
return true;
}
+bool ParserImpl::IsGlslExtendedInstruction(
+ const spvtools::opt::Instruction& inst) const {
+ return (inst.opcode() == SpvOpExtInst) &&
+ (glsl_std_450_imports_.count(inst.GetSingleWordInOperand(0)) > 0);
+}
+
bool ParserImpl::RegisterUserAndStructMemberNames() {
if (!success_) {
return false;
@@ -1389,10 +1451,21 @@
return nullptr;
}
-TypedExpression ParserImpl::RectifyOperandSignedness(SpvOp op,
- TypedExpression&& expr) {
- const bool requires_signed = AssumesSignedOperands(op);
- const bool requires_unsigned = AssumesUnsignedOperands(op);
+TypedExpression ParserImpl::RectifyOperandSignedness(
+ const spvtools::opt::Instruction& inst,
+ TypedExpression&& expr) {
+ bool requires_signed = false;
+ bool requires_unsigned = false;
+ if (IsGlslExtendedInstruction(inst)) {
+ const auto extended_opcode =
+ static_cast<GLSLstd450>(inst.GetSingleWordInOperand(1));
+ requires_signed = AssumesSignedOperands(extended_opcode);
+ requires_unsigned = AssumesUnsignedOperands(extended_opcode);
+ } else {
+ const auto opcode = inst.opcode();
+ requires_signed = AssumesSignedOperands(opcode);
+ requires_unsigned = AssumesUnsignedOperands(opcode);
+ }
if (!requires_signed && !requires_unsigned) {
// No conversion is required, assuming our tables are complete.
return std::move(expr);
@@ -1425,14 +1498,26 @@
}
ast::type::Type* ParserImpl::ForcedResultType(
- SpvOp op,
+ const spvtools::opt::Instruction& inst,
ast::type::Type* first_operand_type) {
- const bool binary_match_first_operand =
- AssumesResultSignednessMatchesBinaryFirstOperand(op);
- const bool unary_match_operand = (op == SpvOpSNegate) || (op == SpvOpNot);
- if (binary_match_first_operand || unary_match_operand) {
+ const auto opcode = inst.opcode();
+ if ((opcode == SpvOpSNegate) || (opcode == SpvOpNot)) {
+ // The unary operation cases that force the result type to match the
+ // first operand type.
return first_operand_type;
}
+ if (AssumesResultSignednessMatchesBinaryFirstOperand(opcode)) {
+ // The binary operation cases that force the result type to match
+ // the first operand type.
+ return first_operand_type;
+ }
+ if (IsGlslExtendedInstruction(inst)) {
+ const auto extended_opcode =
+ static_cast<GLSLstd450>(inst.GetSingleWordInOperand(1));
+ if (AssumesResultSignednessMatchesFirstOperand(extended_opcode)) {
+ return first_operand_type;
+ }
+ }
return nullptr;
}
@@ -1474,9 +1559,9 @@
TypedExpression ParserImpl::RectifyForcedResultType(
TypedExpression expr,
- SpvOp op,
+ const spvtools::opt::Instruction& inst,
ast::type::Type* first_operand_type) {
- auto* forced_result_ty = ForcedResultType(op, first_operand_type);
+ auto* forced_result_ty = ForcedResultType(inst, first_operand_type);
if ((forced_result_ty == nullptr) || (forced_result_ty == expr.type)) {
return expr;
}
diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h
index 346fecb..f5b9689 100644
--- a/src/reader/spirv/parser_impl.h
+++ b/src/reader/spirv/parser_impl.h
@@ -235,6 +235,12 @@
/// @returns true if parser is still successful.
bool RegisterExtendedInstructionImports();
+ // Returns true when the given instruction is an extended instruction
+ // for GLSL.std.450.
+ // @param inst a SPIR-V instruction
+ // @returns true if its an SpvOpExtInst for GLSL.std.450
+ bool IsGlslExtendedInstruction(const spvtools::opt::Instruction& inst) const;
+
/// Registers user names for SPIR-V objects, from OpName, and OpMemberName.
/// Also synthesizes struct field names. Ensures uniqueness for names for
/// SPIR-V IDs, and uniqueness of names of fields within any single struct.
@@ -301,25 +307,27 @@
ast::Expression* MakeNullValue(ast::type::Type* type);
/// Converts a given expression to the signedness demanded for an operand
- /// of the given SPIR-V opcode, if required. If the operation assumes
+ /// of the given SPIR-V instruction, if required. If the instruction assumes
/// signed integer operands, and `expr` is unsigned, then return an
/// as-cast expression converting it to signed. Otherwise, return
/// `expr` itself. Similarly, convert as required from unsigned
/// to signed. Assumes all SPIR-V types have been mapped to AST types.
- /// @param op the SPIR-V opcode
+ /// @param inst the SPIR-V instruction
/// @param expr an expression
/// @returns expr, or a cast of expr
- TypedExpression RectifyOperandSignedness(SpvOp op, TypedExpression&& expr);
+ TypedExpression RectifyOperandSignedness(
+ const spvtools::opt::Instruction& inst,
+ TypedExpression&& expr);
- /// Returns the "forced" result type for the given SPIR-V opcode.
+ /// Returns the "forced" result type for the given SPIR-V instruction.
/// If the WGSL result type for an operation has a more strict rule than
/// requried by SPIR-V, then we say the result type is "forced". This occurs
/// for signed integer division (OpSDiv), for example, where the result type
/// in WGSL must match the operand types.
- /// @param op the SPIR-V opcode
+ /// @param inst the SPIR-V instruction
/// @param first_operand_type the AST type for the first operand.
/// @returns the forced AST result type, or nullptr if no forcing is required.
- ast::type::Type* ForcedResultType(SpvOp op,
+ ast::type::Type* ForcedResultType(const spvtools::opt::Instruction& inst,
ast::type::Type* first_operand_type);
/// Returns a signed integer scalar or vector type matching the shape (scalar,
@@ -343,12 +351,13 @@
/// from the expression's result type. Otherwise, returns the given expression
/// unchanged.
/// @param expr the expression to pass through or to wrap
- /// @param op the SPIR-V opcode
+ /// @param inst the SPIR-V instruction
/// @param first_operand_type the AST type for the first operand.
/// @returns the forced AST result type, or nullptr if no forcing is required.
- TypedExpression RectifyForcedResultType(TypedExpression expr,
- SpvOp op,
- ast::type::Type* first_operand_type);
+ TypedExpression RectifyForcedResultType(
+ TypedExpression expr,
+ const spvtools::opt::Instruction& inst,
+ ast::type::Type* first_operand_type);
/// @returns the registered boolean type.
ast::type::Type* Bool() const { return bool_type_; }