Fixup non-const identifiers in type constructor.
As long as a type constructor is not global the values can be non-const
which means they don't have to be constructors. This CL fixes an issue
where we incorrectly assumed the value was a constructor.
Change-Id: Ib1661830cbb14298ea9254145edd60b74e0dee1d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/20344
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 820accd..bdaba18 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -362,7 +362,13 @@
bool TypeDeterminer::DetermineConstructor(ast::ConstructorExpression* expr) {
if (expr->IsTypeConstructor()) {
- expr->set_result_type(expr->AsTypeConstructor()->type());
+ auto* ty = expr->AsTypeConstructor();
+ for (const auto& value : ty->values()) {
+ if (!DetermineResultType(value.get())) {
+ return false;
+ }
+ }
+ expr->set_result_type(ty->type());
} else {
expr->set_result_type(expr->AsScalarConstructor()->literal()->type());
}
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 47737b4..3fb12c4 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -172,10 +172,7 @@
}
// If the thing we're assigning is a pointer then we must load it first.
- if (assign->rhs()->result_type()->IsPointer()) {
- rhs_id =
- GenerateLoad(assign->rhs()->result_type()->UnwrapPtrIfNeeded(), rhs_id);
- }
+ rhs_id = GenerateLoadIfNeeded(assign->rhs()->result_type(), rhs_id);
GenerateStore(lhs_id, rhs_id);
return true;
@@ -520,7 +517,11 @@
return 0;
}
-uint32_t Builder::GenerateLoad(ast::type::Type* type, uint32_t id) {
+uint32_t Builder::GenerateLoadIfNeeded(ast::type::Type* type, uint32_t id) {
+ if (!type->IsPointer()) {
+ return id;
+ }
+
auto type_id = GenerateTypeIfNeeded(type->UnwrapPtrIfNeeded());
auto result = result_op();
auto result_id = result.to_i();
@@ -599,8 +600,13 @@
}
constructor_is_const = false;
}
- auto id =
- GenerateConstructorExpression(e->AsConstructor(), is_global_init);
+ uint32_t id = 0;
+ if (constructor_is_const) {
+ id = GenerateConstructorExpression(e->AsConstructor(), is_global_init);
+ } else {
+ id = GenerateExpression(e.get());
+ id = GenerateLoadIfNeeded(e->result_type(), id);
+ }
if (id == 0) {
return 0;
}
@@ -676,17 +682,13 @@
if (lhs_id == 0) {
return 0;
}
- if (expr->lhs()->result_type()->IsPointer()) {
- lhs_id = GenerateLoad(expr->lhs()->result_type(), lhs_id);
- }
+ lhs_id = GenerateLoadIfNeeded(expr->lhs()->result_type(), lhs_id);
auto rhs_id = GenerateExpression(expr->rhs());
if (rhs_id == 0) {
return 0;
}
- if (expr->rhs()->result_type()->IsPointer()) {
- rhs_id = GenerateLoad(expr->rhs()->result_type(), rhs_id);
- }
+ rhs_id = GenerateLoadIfNeeded(expr->rhs()->result_type(), rhs_id);
auto result = result_op();
auto result_id = result.to_i();
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index c30f408..a425cb1 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -239,8 +239,8 @@
/// Geneates an OpLoad
/// @param type the type to load
/// @param id the variable id to load
- /// @returns the ID of the loaded value
- uint32_t GenerateLoad(ast::type::Type* type, uint32_t id);
+ /// @returns the ID of the loaded value or |id| if type is not a pointer
+ uint32_t GenerateLoadIfNeeded(ast::type::Type* type, uint32_t id);
/// Geneates an OpStore
/// @param to the ID to store too
/// @param from the ID to store from
diff --git a/src/writer/spirv/builder_constructor_expression_test.cc b/src/writer/spirv/builder_constructor_expression_test.cc
index 0bc3da1..efa3875 100644
--- a/src/writer/spirv/builder_constructor_expression_test.cc
+++ b/src/writer/spirv/builder_constructor_expression_test.cc
@@ -19,10 +19,13 @@
#include "spirv/unified1/spirv.hpp11"
#include "src/ast/binary_expression.h"
#include "src/ast/float_literal.h"
+#include "src/ast/identifier_expression.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
+#include "src/context.h"
+#include "src/type_determiner.h"
#include "src/writer/spirv/builder.h"
#include "src/writer/spirv/spv_dump.h"
@@ -75,6 +78,48 @@
)");
}
+TEST_F(BuilderTest, Constructor_Type_NonConstructorParam) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec(&f32, 2);
+
+ auto var = std::make_unique<ast::Variable>(
+ "ident", ast::StorageClass::kFunction, &f32);
+
+ ast::ExpressionList vals;
+ vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals.push_back(std::make_unique<ast::IdentifierExpression>("ident"));
+
+ ast::TypeConstructorExpression t(&vec, std::move(vals));
+
+ Context ctx;
+ ast::Module mod;
+ TypeDeterminer td(&ctx, &mod);
+ td.RegisterVariableForTesting(var.get());
+ EXPECT_TRUE(td.DetermineResultType(&t)) << td.error();
+
+ Builder b(&mod);
+ b.push_function(Function{});
+ ASSERT_TRUE(b.GenerateFunctionVariable(var.get())) << b.error();
+
+ EXPECT_EQ(b.GenerateConstructorExpression(&t, false), 7u);
+ ASSERT_FALSE(b.has_error()) << b.error();
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32
+%2 = OpTypePointer Function %3
+%4 = OpTypeVector %3 2
+%5 = OpConstant %3 1
+)");
+ EXPECT_EQ(DumpInstructions(b.functions()[0].variables()),
+ R"(%1 = OpVariable %2 Function
+)");
+
+ EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+ R"(%6 = OpLoad %3 %1
+%7 = OpCompositeConstruct %4 %5 %6
+)");
+}
+
TEST_F(BuilderTest, Constructor_Type_Dedups) {
ast::type::F32Type f32;
ast::type::VectorType vec(&f32, 3);