Import Tint changes from Dawn
Changes:
- 4b70776aed5923ce09ed6025c1914317246e2531 tint: Fix transform errors when calling arrayLength() as ... by Ben Clayton <bclayton@google.com>
- 3526bc4f428aa1a6ef77aedd9b5b91d049bec0ec Remove GLSL override generation by dan sinclair <dsinclair@chromium.org>
- e1854b2d721dc07224291dcfa2c842ff749212ef Remove SPIR-V override generation by dan sinclair <dsinclair@chromium.org>
- f6a9404978b0030b79c5d35f38a13186e1c88dc6 Remove HLSL override generation by dan sinclair <dsinclair@chromium.org>
- 5432767a04d5f0df835b5c563256f9aaf6e24544 Remove MSL override generation by dan sinclair <dsinclair@chromium.org>
- 93df967003abc34dc283ffc575850492ce9d6e4c Update StructMember{Offset,Size}Attribute to expressions. by dan sinclair <dsinclair@chromium.org>
GitOrigin-RevId: 4b70776aed5923ce09ed6025c1914317246e2531
Change-Id: I0c4dfcc8c97a1bdbea626f207fa43ca3beaf8acd
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/101940
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
diff --git a/src/tint/ast/struct_member_offset_attribute.cc b/src/tint/ast/struct_member_offset_attribute.cc
index 48d7333..8b9faab 100644
--- a/src/tint/ast/struct_member_offset_attribute.cc
+++ b/src/tint/ast/struct_member_offset_attribute.cc
@@ -25,8 +25,8 @@
StructMemberOffsetAttribute::StructMemberOffsetAttribute(ProgramID pid,
NodeID nid,
const Source& src,
- uint32_t o)
- : Base(pid, nid, src), offset(o) {}
+ const ast::Expression* exp)
+ : Base(pid, nid, src), expr(exp) {}
StructMemberOffsetAttribute::~StructMemberOffsetAttribute() = default;
@@ -37,7 +37,8 @@
const StructMemberOffsetAttribute* StructMemberOffsetAttribute::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source);
- return ctx->dst->create<StructMemberOffsetAttribute>(src, offset);
+ auto expr_ = ctx->Clone(expr);
+ return ctx->dst->create<StructMemberOffsetAttribute>(src, expr_);
}
} // namespace tint::ast
diff --git a/src/tint/ast/struct_member_offset_attribute.h b/src/tint/ast/struct_member_offset_attribute.h
index 790927e..632d71a 100644
--- a/src/tint/ast/struct_member_offset_attribute.h
+++ b/src/tint/ast/struct_member_offset_attribute.h
@@ -18,6 +18,7 @@
#include <string>
#include "src/tint/ast/attribute.h"
+#include "src/tint/ast/expression.h"
namespace tint::ast {
@@ -37,8 +38,11 @@
/// @param pid the identifier of the program that owns this node
/// @param nid the unique node identifier
/// @param src the source of this node
- /// @param offset the offset value
- StructMemberOffsetAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t offset);
+ /// @param expr the offset expression
+ StructMemberOffsetAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const ast::Expression* expr);
~StructMemberOffsetAttribute() override;
/// @returns the WGSL name for the attribute
@@ -50,8 +54,8 @@
/// @return the newly cloned node
const StructMemberOffsetAttribute* Clone(CloneContext* ctx) const override;
- /// The offset value
- const uint32_t offset;
+ /// The offset expression
+ const ast::Expression* const expr;
};
} // namespace tint::ast
diff --git a/src/tint/ast/struct_member_offset_attribute_test.cc b/src/tint/ast/struct_member_offset_attribute_test.cc
index 9d81ffb..23dd697 100644
--- a/src/tint/ast/struct_member_offset_attribute_test.cc
+++ b/src/tint/ast/struct_member_offset_attribute_test.cc
@@ -17,11 +17,13 @@
namespace tint::ast {
namespace {
+using namespace tint::number_suffixes; // NOLINT
using StructMemberOffsetAttributeTest = TestHelper;
TEST_F(StructMemberOffsetAttributeTest, Creation) {
- auto* d = create<StructMemberOffsetAttribute>(2u);
- EXPECT_EQ(2u, d->offset);
+ auto* d = MemberOffset(2_u);
+ ASSERT_TRUE(d->expr->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(2u, d->expr->As<ast::IntLiteralExpression>()->value);
}
} // namespace
diff --git a/src/tint/ast/struct_member_size_attribute.cc b/src/tint/ast/struct_member_size_attribute.cc
index 3919078..833896d 100644
--- a/src/tint/ast/struct_member_size_attribute.cc
+++ b/src/tint/ast/struct_member_size_attribute.cc
@@ -26,8 +26,8 @@
StructMemberSizeAttribute::StructMemberSizeAttribute(ProgramID pid,
NodeID nid,
const Source& src,
- uint32_t sz)
- : Base(pid, nid, src), size(sz) {}
+ const ast::Expression* exp)
+ : Base(pid, nid, src), expr(exp) {}
StructMemberSizeAttribute::~StructMemberSizeAttribute() = default;
@@ -38,7 +38,8 @@
const StructMemberSizeAttribute* StructMemberSizeAttribute::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source);
- return ctx->dst->create<StructMemberSizeAttribute>(src, size);
+ auto expr_ = ctx->Clone(expr);
+ return ctx->dst->create<StructMemberSizeAttribute>(src, expr_);
}
} // namespace tint::ast
diff --git a/src/tint/ast/struct_member_size_attribute.h b/src/tint/ast/struct_member_size_attribute.h
index 5649e2e..c048b1f 100644
--- a/src/tint/ast/struct_member_size_attribute.h
+++ b/src/tint/ast/struct_member_size_attribute.h
@@ -19,6 +19,7 @@
#include <string>
#include "src/tint/ast/attribute.h"
+#include "src/tint/ast/expression.h"
namespace tint::ast {
@@ -29,8 +30,11 @@
/// @param pid the identifier of the program that owns this node
/// @param nid the unique node identifier
/// @param src the source of this node
- /// @param size the size value
- StructMemberSizeAttribute(ProgramID pid, NodeID nid, const Source& src, uint32_t size);
+ /// @param expr the size expression
+ StructMemberSizeAttribute(ProgramID pid,
+ NodeID nid,
+ const Source& src,
+ const ast::Expression* expr);
~StructMemberSizeAttribute() override;
/// @returns the WGSL name for the attribute
@@ -42,8 +46,8 @@
/// @return the newly cloned node
const StructMemberSizeAttribute* Clone(CloneContext* ctx) const override;
- /// The size value
- const uint32_t size;
+ /// The size expression
+ const ast::Expression* const expr;
};
} // namespace tint::ast
diff --git a/src/tint/ast/struct_member_size_attribute_test.cc b/src/tint/ast/struct_member_size_attribute_test.cc
index a82d53a..9f6b49b 100644
--- a/src/tint/ast/struct_member_size_attribute_test.cc
+++ b/src/tint/ast/struct_member_size_attribute_test.cc
@@ -19,11 +19,13 @@
namespace tint::ast {
namespace {
+using namespace tint::number_suffixes; // NOLINT
using StructMemberSizeAttributeTest = TestHelper;
TEST_F(StructMemberSizeAttributeTest, Creation) {
- auto* d = create<StructMemberSizeAttribute>(2u);
- EXPECT_EQ(2u, d->size);
+ auto* d = MemberSize(2_u);
+ ASSERT_TRUE(d->expr->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(2u, d->expr->As<ast::IntLiteralExpression>()->value);
}
} // namespace
diff --git a/src/tint/ast/struct_member_test.cc b/src/tint/ast/struct_member_test.cc
index 69772f4..850be9f 100644
--- a/src/tint/ast/struct_member_test.cc
+++ b/src/tint/ast/struct_member_test.cc
@@ -18,10 +18,11 @@
namespace tint::ast {
namespace {
+using namespace tint::number_suffixes; // NOLINT
using StructMemberTest = TestHelper;
TEST_F(StructMemberTest, Creation) {
- auto* st = Member("a", ty.i32(), utils::Vector{MemberSize(4)});
+ auto* st = Member("a", ty.i32(), utils::Vector{MemberSize(4_a)});
EXPECT_EQ(st->symbol, Symbol(1, ID()));
EXPECT_TRUE(st->type->Is<ast::I32>());
EXPECT_EQ(st->attributes.Length(), 1u);
@@ -66,7 +67,7 @@
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
- b.Member("a", b.ty.i32(), utils::Vector{b.MemberSize(4), nullptr});
+ b.Member("a", b.ty.i32(), utils::Vector{b.MemberSize(4_a), nullptr});
},
"internal compiler error");
}
@@ -76,7 +77,7 @@
{
ProgramBuilder b1;
ProgramBuilder b2;
- b1.Member(b2.Sym("a"), b1.ty.i32(), utils::Vector{b1.MemberSize(4)});
+ b1.Member(b2.Sym("a"), b1.ty.i32(), utils::Vector{b1.MemberSize(4_a)});
},
"internal compiler error");
}
@@ -86,7 +87,7 @@
{
ProgramBuilder b1;
ProgramBuilder b2;
- b1.Member("a", b1.ty.i32(), utils::Vector{b2.MemberSize(4)});
+ b1.Member("a", b1.ty.i32(), utils::Vector{b2.MemberSize(4_a)});
},
"internal compiler error");
}
diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h
index ead13e9..c94f5fb 100644
--- a/src/tint/program_builder.h
+++ b/src/tint/program_builder.h
@@ -2251,25 +2251,37 @@
}
/// Creates a ast::StructMemberOffsetAttribute
- /// @param val the offset value
+ /// @param val the offset expression
/// @returns the offset attribute pointer
- const ast::StructMemberOffsetAttribute* MemberOffset(uint32_t val) {
- return create<ast::StructMemberOffsetAttribute>(source_, val);
+ template <typename EXPR>
+ const ast::StructMemberOffsetAttribute* MemberOffset(EXPR&& val) {
+ return create<ast::StructMemberOffsetAttribute>(source_, Expr(std::forward<EXPR>(val)));
+ }
+
+ /// Creates a ast::StructMemberOffsetAttribute
+ /// @param source the source information
+ /// @param val the offset expression
+ /// @returns the offset attribute pointer
+ template <typename EXPR>
+ const ast::StructMemberOffsetAttribute* MemberOffset(const Source& source, EXPR&& val) {
+ return create<ast::StructMemberOffsetAttribute>(source, Expr(std::forward<EXPR>(val)));
}
/// Creates a ast::StructMemberSizeAttribute
/// @param source the source information
/// @param val the size value
/// @returns the size attribute pointer
- const ast::StructMemberSizeAttribute* MemberSize(const Source& source, uint32_t val) {
- return create<ast::StructMemberSizeAttribute>(source, val);
+ template <typename EXPR>
+ const ast::StructMemberSizeAttribute* MemberSize(const Source& source, EXPR&& val) {
+ return create<ast::StructMemberSizeAttribute>(source, Expr(std::forward<EXPR>(val)));
}
/// Creates a ast::StructMemberSizeAttribute
/// @param val the size value
/// @returns the size attribute pointer
- const ast::StructMemberSizeAttribute* MemberSize(uint32_t val) {
- return create<ast::StructMemberSizeAttribute>(source_, val);
+ template <typename EXPR>
+ const ast::StructMemberSizeAttribute* MemberSize(EXPR&& val) {
+ return create<ast::StructMemberSizeAttribute>(source_, Expr(std::forward<EXPR>(val)));
}
/// Creates a ast::StructMemberAlignAttribute
@@ -2525,7 +2537,7 @@
const ast::StructMember* Member(uint32_t offset, NAME&& name, const ast::Type* type) {
return create<ast::StructMember>(source_, Sym(std::forward<NAME>(name)), type,
utils::Vector<const ast::Attribute*, 1>{
- create<ast::StructMemberOffsetAttribute>(offset),
+ MemberOffset(AInt(offset)),
});
}
diff --git a/src/tint/reader/spirv/parser_impl.cc b/src/tint/reader/spirv/parser_impl.cc
index 37ea5ca..caa2c69 100644
--- a/src/tint/reader/spirv/parser_impl.cc
+++ b/src/tint/reader/spirv/parser_impl.cc
@@ -464,7 +464,7 @@
return {};
}
return {
- create<ast::StructMemberOffsetAttribute>(Source{}, decoration[1]),
+ builder_.MemberOffset(Source{}, AInt(decoration[1])),
};
case SpvDecorationNonReadable:
// WGSL doesn't have a member decoration for this. Silently drop it.
diff --git a/src/tint/reader/spirv/parser_impl_convert_member_decoration_test.cc b/src/tint/reader/spirv/parser_impl_convert_member_decoration_test.cc
index cd3e7b5..767b333 100644
--- a/src/tint/reader/spirv/parser_impl_convert_member_decoration_test.cc
+++ b/src/tint/reader/spirv/parser_impl_convert_member_decoration_test.cc
@@ -54,7 +54,8 @@
EXPECT_TRUE(result[0]->Is<ast::StructMemberOffsetAttribute>());
auto* offset_deco = result[0]->As<ast::StructMemberOffsetAttribute>();
ASSERT_NE(offset_deco, nullptr);
- EXPECT_EQ(offset_deco->offset, 8u);
+ ASSERT_TRUE(offset_deco->expr->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(offset_deco->expr->As<ast::IntLiteralExpression>()->value, 8u);
EXPECT_TRUE(p->error().empty());
}
diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc
index f3b424e..2c72d6a 100644
--- a/src/tint/reader/wgsl/parser_impl.cc
+++ b/src/tint/reader/wgsl/parser_impl.cc
@@ -3566,7 +3566,7 @@
}
match(Token::Type::kComma);
- return create<ast::StructMemberSizeAttribute>(t.source(), val.value);
+ return builder_.MemberSize(t.source(), AInt(val.value));
});
}
diff --git a/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc b/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
index cc6d186..9bafaf0 100644
--- a/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_struct_member_attribute_test.cc
@@ -30,7 +30,8 @@
ASSERT_TRUE(member_attr->Is<ast::StructMemberSizeAttribute>());
auto* o = member_attr->As<ast::StructMemberSizeAttribute>();
- EXPECT_EQ(o->size, 4u);
+ ASSERT_TRUE(o->expr->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(o->expr->As<ast::IntLiteralExpression>()->value, 4u);
}
TEST_F(ParserImplTest, Attribute_Size_TrailingComma) {
@@ -46,7 +47,8 @@
ASSERT_TRUE(member_attr->Is<ast::StructMemberSizeAttribute>());
auto* o = member_attr->As<ast::StructMemberSizeAttribute>();
- EXPECT_EQ(o->size, 4u);
+ ASSERT_TRUE(o->expr->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(o->expr->As<ast::IntLiteralExpression>()->value, 4u);
}
TEST_F(ParserImplTest, Attribute_Size_MissingLeftParen) {
diff --git a/src/tint/reader/wgsl/parser_impl_struct_member_test.cc b/src/tint/reader/wgsl/parser_impl_struct_member_test.cc
index fe233fe..6267406 100644
--- a/src/tint/reader/wgsl/parser_impl_struct_member_test.cc
+++ b/src/tint/reader/wgsl/parser_impl_struct_member_test.cc
@@ -73,8 +73,11 @@
EXPECT_EQ(m->symbol, builder.Symbols().Get("a"));
EXPECT_TRUE(m->type->Is<ast::I32>());
EXPECT_EQ(m->attributes.Length(), 1u);
- EXPECT_TRUE(m->attributes[0]->Is<ast::StructMemberSizeAttribute>());
- EXPECT_EQ(m->attributes[0]->As<ast::StructMemberSizeAttribute>()->size, 2u);
+ ASSERT_TRUE(m->attributes[0]->Is<ast::StructMemberSizeAttribute>());
+ auto* s = m->attributes[0]->As<ast::StructMemberSizeAttribute>();
+
+ ASSERT_TRUE(s->expr->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(s->expr->As<ast::IntLiteralExpression>()->value, 2u);
EXPECT_EQ(m->source.range, (Source::Range{{1u, 10u}, {1u, 11u}}));
EXPECT_EQ(m->type->source.range, (Source::Range{{1u, 14u}, {1u, 17u}}));
@@ -95,7 +98,9 @@
EXPECT_TRUE(m->type->Is<ast::I32>());
EXPECT_EQ(m->attributes.Length(), 2u);
ASSERT_TRUE(m->attributes[0]->Is<ast::StructMemberSizeAttribute>());
- EXPECT_EQ(m->attributes[0]->As<ast::StructMemberSizeAttribute>()->size, 2u);
+ auto* size_attr = m->attributes[0]->As<ast::StructMemberSizeAttribute>();
+ ASSERT_TRUE(size_attr->expr->Is<ast::IntLiteralExpression>());
+ EXPECT_EQ(size_attr->expr->As<ast::IntLiteralExpression>()->value, 2u);
ASSERT_TRUE(m->attributes[1]->Is<ast::StructMemberAlignAttribute>());
auto* attr = m->attributes[1]->As<ast::StructMemberAlignAttribute>();
diff --git a/src/tint/resolver/attribute_validation_test.cc b/src/tint/resolver/attribute_validation_test.cc
index b182f6d..e648562 100644
--- a/src/tint/resolver/attribute_validation_test.cc
+++ b/src/tint/resolver/attribute_validation_test.cc
@@ -106,9 +106,9 @@
case AttributeKind::kLocation:
return {builder.Location(source, 1_a)};
case AttributeKind::kOffset:
- return {builder.create<ast::StructMemberOffsetAttribute>(source, 4u)};
+ return {builder.MemberOffset(source, 4_a)};
case AttributeKind::kSize:
- return {builder.create<ast::StructMemberSizeAttribute>(source, 16u)};
+ return {builder.MemberSize(source, 16_a)};
case AttributeKind::kStage:
return {builder.Stage(source, ast::PipelineStage::kCompute)};
case AttributeKind::kStride:
diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc
index 27dc219..911b66c 100644
--- a/src/tint/resolver/resolver.cc
+++ b/src/tint/resolver/resolver.cc
@@ -2803,11 +2803,22 @@
if (auto* o = attr->As<ast::StructMemberOffsetAttribute>()) {
// Offset attributes are not part of the WGSL spec, but are emitted
// by the SPIR-V reader.
- if (o->offset < struct_size) {
+
+ auto* materialized = Materialize(Expression(o->expr));
+ if (!materialized) {
+ return nullptr;
+ }
+ auto const_value = materialized->ConstantValue();
+ if (!const_value) {
+ AddError("'offset' must be constant expression", o->expr->source);
+ return nullptr;
+ }
+ offset = const_value->As<uint64_t>();
+
+ if (offset < struct_size) {
AddError("offsets must be in ascending order", o->source);
return nullptr;
}
- offset = o->offset;
align = 1;
has_offset_attr = true;
} else if (auto* a = attr->As<ast::StructMemberAlignAttribute>()) {
@@ -2829,13 +2840,24 @@
align = const_value->As<u32>();
has_align_attr = true;
} else if (auto* s = attr->As<ast::StructMemberSizeAttribute>()) {
- if (s->size < size) {
+ auto* materialized = Materialize(Expression(s->expr));
+ if (!materialized) {
+ return nullptr;
+ }
+ auto const_value = materialized->ConstantValue();
+ if (!const_value) {
+ AddError("'size' must be constant expression", s->expr->source);
+ return nullptr;
+ }
+ auto value = const_value->As<uint64_t>();
+
+ if (value < size) {
AddError("size must be at least as big as the type's size (" +
std::to_string(size) + ")",
s->source);
return nullptr;
}
- size = s->size;
+ size = const_value->As<u32>();
has_size_attr = true;
} else if (auto* l = attr->As<ast::LocationAttribute>()) {
auto* materialize = Materialize(Expression(l->expr));
diff --git a/src/tint/resolver/storage_class_layout_validation_test.cc b/src/tint/resolver/storage_class_layout_validation_test.cc
index 8f1b2c1..91d0d5d 100644
--- a/src/tint/resolver/storage_class_layout_validation_test.cc
+++ b/src/tint/resolver/storage_class_layout_validation_test.cc
@@ -35,7 +35,7 @@
Structure(Source{{12, 34}}, "S",
utils::Vector{
- Member("a", ty.f32(), utils::Vector{MemberSize(5)}),
+ Member("a", ty.f32(), utils::Vector{MemberSize(5_a)}),
Member(Source{{34, 56}}, "b", ty.f32(), utils::Vector{MemberAlign(1_u)}),
});
@@ -65,7 +65,7 @@
Structure(Source{{12, 34}}, "S",
utils::Vector{
- Member("a", ty.f32(), utils::Vector{MemberSize(5)}),
+ Member("a", ty.f32(), utils::Vector{MemberSize(5_a)}),
Member(Source{{34, 56}}, "b", ty.f32(), utils::Vector{MemberAlign(4_u)}),
});
@@ -227,7 +227,7 @@
Structure(Source{{12, 34}}, "Inner",
utils::Vector{
- Member("scalar", ty.i32(), utils::Vector{MemberAlign(1_u), MemberSize(5)}),
+ Member("scalar", ty.i32(), utils::Vector{MemberAlign(1_u), MemberSize(5_a)}),
});
Structure(Source{{34, 56}}, "Outer",
@@ -279,7 +279,7 @@
Member("a", ty.i32()),
Member("b", ty.i32()),
Member("c", ty.i32()),
- Member("scalar", ty.i32(), utils::Vector{MemberAlign(1_u), MemberSize(5)}),
+ Member("scalar", ty.i32(), utils::Vector{MemberAlign(1_u), MemberSize(5_a)}),
});
Structure(Source{{34, 56}}, "Outer",
@@ -327,7 +327,7 @@
Structure(Source{{12, 34}}, "Inner",
utils::Vector{
- Member("scalar", ty.i32(), utils::Vector{MemberAlign(1_u), MemberSize(5)}),
+ Member("scalar", ty.i32(), utils::Vector{MemberAlign(1_u), MemberSize(5_a)}),
});
Structure(Source{{34, 56}}, "Outer",
@@ -550,7 +550,7 @@
Enable(ast::Extension::kChromiumExperimentalPushConstant);
Structure(
Source{{12, 34}}, "S",
- utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize(5)}),
+ utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize(5_a)}),
Member(Source{{34, 56}}, "b", ty.f32(), utils::Vector{MemberAlign(1_u)})});
GlobalVar(Source{{78, 90}}, "a", ty.type_name("S"), ast::StorageClass::kPushConstant);
@@ -575,7 +575,7 @@
// };
// var<push_constant> a : S;
Enable(ast::Extension::kChromiumExperimentalPushConstant);
- Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize(5)}),
+ Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize(5_a)}),
Member("b", ty.f32(), utils::Vector{MemberAlign(4_u)})});
GlobalVar("a", ty.type_name("S"), ast::StorageClass::kPushConstant);
diff --git a/src/tint/resolver/storage_class_validation_test.cc b/src/tint/resolver/storage_class_validation_test.cc
index 8c1a253..cff20c0 100644
--- a/src/tint/resolver/storage_class_validation_test.cc
+++ b/src/tint/resolver/storage_class_validation_test.cc
@@ -400,7 +400,7 @@
Enable(ast::Extension::kF16);
auto* s = Structure(
- "S", utils::Vector{Member("a", ty.f16(Source{{56, 78}}), utils::Vector{MemberSize(16)})});
+ "S", utils::Vector{Member("a", ty.f16(Source{{56, 78}}), utils::Vector{MemberSize(16_a)})});
auto* a = ty.array(ty.Of(s), 3_u);
GlobalVar("g", a, ast::StorageClass::kUniform, Binding(0_a), Group(0_a));
@@ -474,7 +474,7 @@
// @size(16) f : f32;
// }
// var<uniform> g : array<S, 3u>;
- auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize(16)})});
+ auto* s = Structure("S", utils::Vector{Member("a", ty.f32(), utils::Vector{MemberSize(16_a)})});
auto* a = ty.array(ty.Of(s), 3_u);
GlobalVar(Source{{56, 78}}, "g", a, ast::StorageClass::kUniform, Binding(0_a), Group(0_a));
diff --git a/src/tint/resolver/struct_layout_test.cc b/src/tint/resolver/struct_layout_test.cc
index ae3972b..5b5ab68 100644
--- a/src/tint/resolver/struct_layout_test.cc
+++ b/src/tint/resolver/struct_layout_test.cc
@@ -460,15 +460,15 @@
TEST_F(ResolverStructLayoutTest, SizeAttributes) {
auto* inner = Structure("Inner", utils::Vector{
- Member("a", ty.f32(), utils::Vector{MemberSize(8)}),
- Member("b", ty.f32(), utils::Vector{MemberSize(16)}),
- Member("c", ty.f32(), utils::Vector{MemberSize(8)}),
+ Member("a", ty.f32(), utils::Vector{MemberSize(8_a)}),
+ Member("b", ty.f32(), utils::Vector{MemberSize(16_a)}),
+ Member("c", ty.f32(), utils::Vector{MemberSize(8_a)}),
});
auto* s = Structure("S", utils::Vector{
- Member("a", ty.f32(), utils::Vector{MemberSize(4)}),
- Member("b", ty.u32(), utils::Vector{MemberSize(8)}),
+ Member("a", ty.f32(), utils::Vector{MemberSize(4_a)}),
+ Member("b", ty.u32(), utils::Vector{MemberSize(8_a)}),
Member("c", ty.Of(inner)),
- Member("d", ty.i32(), utils::Vector{MemberSize(32)}),
+ Member("d", ty.i32(), utils::Vector{MemberSize(32_a)}),
});
ASSERT_TRUE(r()->Resolve()) << r()->error();
diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc
index 6054cac..90b1664 100644
--- a/src/tint/resolver/validation_test.cc
+++ b/src/tint/resolver/validation_test.cc
@@ -1257,7 +1257,7 @@
TEST_F(ResolverValidationTest, ZeroStructMemberSizeAttribute) {
Structure("S", utils::Vector{
- Member("a", ty.f32(), utils::Vector{MemberSize(Source{{12, 34}}, 0)}),
+ Member("a", ty.f32(), utils::Vector{MemberSize(Source{{12, 34}}, 0_a)}),
});
EXPECT_FALSE(r()->Resolve());
@@ -1267,7 +1267,7 @@
TEST_F(ResolverValidationTest, OffsetAndSizeAttribute) {
Structure("S", utils::Vector{
Member(Source{{12, 34}}, "a", ty.f32(),
- utils::Vector{MemberOffset(0), MemberSize(4)}),
+ utils::Vector{MemberOffset(0_a), MemberSize(4_a)}),
});
EXPECT_FALSE(r()->Resolve());
@@ -1279,7 +1279,7 @@
TEST_F(ResolverValidationTest, OffsetAndAlignAttribute) {
Structure("S", utils::Vector{
Member(Source{{12, 34}}, "a", ty.f32(),
- utils::Vector{MemberOffset(0), MemberAlign(4_u)}),
+ utils::Vector{MemberOffset(0_a), MemberAlign(4_u)}),
});
EXPECT_FALSE(r()->Resolve());
@@ -1291,7 +1291,7 @@
TEST_F(ResolverValidationTest, OffsetAndAlignAndSizeAttribute) {
Structure("S", utils::Vector{
Member(Source{{12, 34}}, "a", ty.f32(),
- utils::Vector{MemberOffset(0), MemberAlign(4_u), MemberSize(4)}),
+ utils::Vector{MemberOffset(0_a), MemberAlign(4_u), MemberSize(4_a)}),
});
EXPECT_FALSE(r()->Resolve());
diff --git a/src/tint/transform/array_length_from_uniform.cc b/src/tint/transform/array_length_from_uniform.cc
index 4f264c4..71a5cca 100644
--- a/src/tint/transform/array_length_from_uniform.cc
+++ b/src/tint/transform/array_length_from_uniform.cc
@@ -21,6 +21,7 @@
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
+#include "src/tint/sem/statement.h"
#include "src/tint/sem/variable.h"
#include "src/tint/transform/simplify_pointers.h"
@@ -33,65 +34,79 @@
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
-/// Iterate over all arrayLength() builtins that operate on
-/// storage buffer variables.
-/// @param ctx the CloneContext.
-/// @param functor of type void(const ast::CallExpression*, const
-/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
-/// ast::CallExpression of the arrayLength call expression node, a
-/// sem::VariableUser of the used storage buffer variable, and the
-/// sem::GlobalVariable for the storage buffer.
-template <typename F>
-static void IterateArrayLengthOnStorageVar(CloneContext& ctx, F&& functor) {
- auto& sem = ctx.src->Sem();
+/// The PIMPL state for this transform
+struct ArrayLengthFromUniform::State {
+ /// The clone context
+ CloneContext& ctx;
- // Find all calls to the arrayLength() builtin.
- for (auto* node : ctx.src->ASTNodes().Objects()) {
- auto* call_expr = node->As<ast::CallExpression>();
- if (!call_expr) {
- continue;
- }
+ /// Iterate over all arrayLength() builtins that operate on
+ /// storage buffer variables.
+ /// @param functor of type void(const ast::CallExpression*, const
+ /// sem::VariableUser, const sem::GlobalVariable*). It takes in an
+ /// ast::CallExpression of the arrayLength call expression node, a
+ /// sem::VariableUser of the used storage buffer variable, and the
+ /// sem::GlobalVariable for the storage buffer.
+ template <typename F>
+ void IterateArrayLengthOnStorageVar(F&& functor) {
+ auto& sem = ctx.src->Sem();
- auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
- auto* builtin = call->Target()->As<sem::Builtin>();
- if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
- continue;
- }
+ // Find all calls to the arrayLength() builtin.
+ for (auto* node : ctx.src->ASTNodes().Objects()) {
+ auto* call_expr = node->As<ast::CallExpression>();
+ if (!call_expr) {
+ continue;
+ }
- // Get the storage buffer that contains the runtime array.
- // Since we require SimplifyPointers, we can assume that the arrayLength()
- // call has one of two forms:
- // arrayLength(&struct_var.array_member)
- // arrayLength(&array_var)
- auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
- if (!param || param->op != ast::UnaryOp::kAddressOf) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "expected form of arrayLength argument to be &array_var or "
- "&struct_var.array_member";
- break;
- }
- auto* storage_buffer_expr = param->expr;
- if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
- storage_buffer_expr = accessor->structure;
- }
- auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
- if (!storage_buffer_sem) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "expected form of arrayLength argument to be &array_var or "
- "&struct_var.array_member";
- break;
- }
+ auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
+ auto* builtin = call->Target()->As<sem::Builtin>();
+ if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) {
+ continue;
+ }
- // Get the index to use for the buffer size array.
- auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
- if (!var) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "storage buffer is not a global variable";
- break;
+ if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
+ if (call_stmt->expr == call_expr) {
+ // arrayLength() is used as a statement.
+ // The argument expression must be side-effect free, so just drop the statement.
+ RemoveStatement(ctx, call_stmt);
+ continue;
+ }
+ }
+
+ // Get the storage buffer that contains the runtime array.
+ // Since we require SimplifyPointers, we can assume that the arrayLength()
+ // call has one of two forms:
+ // arrayLength(&struct_var.array_member)
+ // arrayLength(&array_var)
+ auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
+ if (!param || param->op != ast::UnaryOp::kAddressOf) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ break;
+ }
+ auto* storage_buffer_expr = param->expr;
+ if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
+ storage_buffer_expr = accessor->structure;
+ }
+ auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
+ if (!storage_buffer_sem) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "expected form of arrayLength argument to be &array_var or "
+ "&struct_var.array_member";
+ break;
+ }
+
+ // Get the index to use for the buffer size array.
+ auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
+ if (!var) {
+ TINT_ICE(Transform, ctx.dst->Diagnostics())
+ << "storage buffer is not a global variable";
+ break;
+ }
+ functor(call_expr, storage_buffer_sem, var);
}
- functor(call_expr, storage_buffer_sem, var);
}
-}
+};
bool ArrayLengthFromUniform::ShouldRun(const Program* program, const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
@@ -119,17 +134,17 @@
// Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0;
- IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression*, const sem::VariableUser*,
- const sem::GlobalVariable* var) {
- auto binding = var->BindingPoint();
- auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
- if (idx_itr == cfg->bindpoint_to_size_index.end()) {
- return;
- }
- if (idx_itr->second > max_buffer_size_index) {
- max_buffer_size_index = idx_itr->second;
- }
- });
+ State{ctx}.IterateArrayLengthOnStorageVar(
+ [&](const ast::CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) {
+ auto binding = var->BindingPoint();
+ auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
+ if (idx_itr == cfg->bindpoint_to_size_index.end()) {
+ return;
+ }
+ if (idx_itr->second > max_buffer_size_index) {
+ max_buffer_size_index = idx_itr->second;
+ }
+ });
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
@@ -156,9 +171,9 @@
std::unordered_set<uint32_t> used_size_indices;
- IterateArrayLengthOnStorageVar(ctx, [&](const ast::CallExpression* call_expr,
- const sem::VariableUser* storage_buffer_sem,
- const sem::GlobalVariable* var) {
+ State{ctx}.IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
+ const sem::VariableUser* storage_buffer_sem,
+ const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
diff --git a/src/tint/transform/array_length_from_uniform.h b/src/tint/transform/array_length_from_uniform.h
index c34c529..8bd6af5 100644
--- a/src/tint/transform/array_length_from_uniform.h
+++ b/src/tint/transform/array_length_from_uniform.h
@@ -113,6 +113,10 @@
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
+
+ private:
+ /// The PIMPL state for this transform
+ struct State;
};
} // namespace tint::transform
diff --git a/src/tint/transform/array_length_from_uniform_test.cc b/src/tint/transform/array_length_from_uniform_test.cc
index 6663666..1058bf1 100644
--- a/src/tint/transform/array_length_from_uniform_test.cc
+++ b/src/tint/transform/array_length_from_uniform_test.cc
@@ -496,5 +496,43 @@
got.data.Get<ArrayLengthFromUniform::Result>()->used_size_indices);
}
+TEST_F(ArrayLengthFromUniformTest, CallStatement) {
+ auto* src = R"(
+struct SB {
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read> a : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ arrayLength(&a.arr);
+}
+)";
+
+ auto* expect =
+ R"(
+struct SB {
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read> a : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+}
+)";
+
+ ArrayLengthFromUniform::Config cfg({0, 30u});
+ cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
+
+ DataMap data;
+ data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
+
+ auto got = Run<Unshadow, SimplifyPointers, ArrayLengthFromUniform>(src, data);
+
+ EXPECT_EQ(expect, str(got));
+}
+
} // namespace
} // namespace tint::transform
diff --git a/src/tint/transform/calculate_array_length.cc b/src/tint/transform/calculate_array_length.cc
index 8edcbf5..b06f7b6 100644
--- a/src/tint/transform/calculate_array_length.cc
+++ b/src/tint/transform/calculate_array_length.cc
@@ -130,6 +130,16 @@
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
// We're dealing with an arrayLength() call
+ if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
+ if (call_stmt->expr == call_expr) {
+ // arrayLength() is used as a statement.
+ // The argument expression must be side-effect free, so just drop the
+ // statement.
+ RemoveStatement(ctx, call_stmt);
+ continue;
+ }
+ }
+
// A runtime-sized array can only appear as the store type of a variable, or the
// last element of a structure (which cannot itself be nested). Given that we
// require SimplifyPointers, we can assume that the arrayLength() call has one
diff --git a/src/tint/transform/calculate_array_length_test.cc b/src/tint/transform/calculate_array_length_test.cc
index 98f2ed7..826ffd6 100644
--- a/src/tint/transform/calculate_array_length_test.cc
+++ b/src/tint/transform/calculate_array_length_test.cc
@@ -547,5 +547,37 @@
EXPECT_EQ(expect, str(got));
}
+TEST_F(CalculateArrayLengthTest, CallStatement) {
+ auto* src = R"(
+struct SB {
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read> a : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+ arrayLength(&a.arr);
+}
+)";
+
+ auto* expect =
+ R"(
+struct SB {
+ arr : array<i32>,
+}
+
+@group(0) @binding(0) var<storage, read> a : SB;
+
+@compute @workgroup_size(1)
+fn main() {
+}
+)";
+
+ auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);
+
+ EXPECT_EQ(expect, str(got));
+}
+
} // namespace
} // namespace tint::transform
diff --git a/src/tint/transform/decompose_strided_array.cc b/src/tint/transform/decompose_strided_array.cc
index 61841f6..b7fe53c 100644
--- a/src/tint/transform/decompose_strided_array.cc
+++ b/src/tint/transform/decompose_strided_array.cc
@@ -73,7 +73,7 @@
auto* member_ty = ctx.Clone(ast->type);
auto* member = ctx.dst->Member(kMemberName, member_ty,
utils::Vector{
- ctx.dst->MemberSize(arr->Stride()),
+ ctx.dst->MemberSize(AInt(arr->Stride())),
});
ctx.dst->Structure(name, utils::Vector{member});
return name;
diff --git a/src/tint/transform/decompose_strided_matrix_test.cc b/src/tint/transform/decompose_strided_matrix_test.cc
index c7a8f59..6bfba7c 100644
--- a/src/tint/transform/decompose_strided_matrix_test.cc
+++ b/src/tint/transform/decompose_strided_matrix_test.cc
@@ -71,7 +71,7 @@
"S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
- b.create<ast::StructMemberOffsetAttribute>(16u),
+ b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
@@ -127,7 +127,7 @@
"S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
- b.create<ast::StructMemberOffsetAttribute>(16u),
+ b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
@@ -180,7 +180,7 @@
"S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
- b.create<ast::StructMemberOffsetAttribute>(16u),
+ b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(8u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
@@ -233,7 +233,7 @@
"S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
- b.create<ast::StructMemberOffsetAttribute>(8u),
+ b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
@@ -290,7 +290,7 @@
"S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
- b.create<ast::StructMemberOffsetAttribute>(16u),
+ b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
@@ -344,7 +344,7 @@
"S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
- b.create<ast::StructMemberOffsetAttribute>(8u),
+ b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
@@ -402,7 +402,7 @@
"S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
- b.create<ast::StructMemberOffsetAttribute>(8u),
+ b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
@@ -461,7 +461,7 @@
"S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
- b.create<ast::StructMemberOffsetAttribute>(8u),
+ b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
@@ -532,7 +532,7 @@
"S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
- b.create<ast::StructMemberOffsetAttribute>(8u),
+ b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
@@ -585,7 +585,7 @@
"S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
- b.create<ast::StructMemberOffsetAttribute>(8u),
+ b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
}),
diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc
index f46fdea..e6f070b 100644
--- a/src/tint/transform/std140.cc
+++ b/src/tint/transform/std140.cc
@@ -246,9 +246,9 @@
// The matrix was @size() annotated with a larger size than the
// natural size for the matrix. This extra padding needs to be
// applied to the last column vector.
- attributes.Push(
- b.MemberSize(member->Size() - mat->ColumnType()->Size() *
- (num_columns - 1)));
+ attributes.Push(b.MemberSize(
+ AInt(member->Size() -
+ mat->ColumnType()->Size() * (num_columns - 1))));
}
// Build the member
diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc
index 132f542..3f98999 100644
--- a/src/tint/writer/glsl/generator_impl.cc
+++ b/src/tint/writer/glsl/generator_impl.cc
@@ -1884,7 +1884,12 @@
}
},
[&](const ast::Let* let) { return EmitProgramConstVariable(let); },
- [&](const ast::Override* override) { return EmitOverride(override); },
+ [&](const ast::Override*) {
+ // Override is removed with SubstituteOverride
+ TINT_ICE(Writer, diagnostics_)
+ << "Override should have been removed by the substitute_override transform.";
+ return false;
+ },
[&](const ast::Const*) {
return true; // Constants are embedded at their use
},
@@ -2985,38 +2990,6 @@
return true;
}
-bool GeneratorImpl::EmitOverride(const ast::Override* override) {
- auto* sem = builder_.Sem().Get(override);
- auto* type = sem->Type();
-
- auto* global = sem->As<sem::GlobalVariable>();
- auto override_id = global->OverrideId();
-
- line() << "#ifndef " << kSpecConstantPrefix << override_id.value;
-
- if (override->constructor != nullptr) {
- auto out = line();
- out << "#define " << kSpecConstantPrefix << override_id.value << " ";
- if (!EmitExpression(out, override->constructor)) {
- return false;
- }
- } else {
- line() << "#error spec constant required for constant id " << override_id.value;
- }
- line() << "#endif";
- {
- auto out = line();
- out << "const ";
- if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
- builder_.Symbols().NameFor(override->symbol))) {
- return false;
- }
- out << " = " << kSpecConstantPrefix << override_id.value << ";";
- }
-
- return true;
-}
-
template <typename F>
bool GeneratorImpl::CallBuiltinHelper(std::ostream& out,
const ast::CallExpression* call,
diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h
index e70bdc2..cab6881 100644
--- a/src/tint/writer/glsl/generator_impl.h
+++ b/src/tint/writer/glsl/generator_impl.h
@@ -457,10 +457,6 @@
/// @param let the 'let' to emit
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Variable* let);
- /// Handles generating a module-scope 'override' declaration
- /// @param override the 'override' to emit
- /// @returns true if the variable was emitted
- bool EmitOverride(const ast::Override* override);
/// Handles generating a builtin method name
/// @param builtin the semantic info for the builtin
/// @returns the name or "" if not valid
diff --git a/src/tint/writer/glsl/generator_impl_function_test.cc b/src/tint/writer/glsl/generator_impl_function_test.cc
index fd74e2d..473c30c 100644
--- a/src/tint/writer/glsl/generator_impl_function_test.cc
+++ b/src/tint/writer/glsl/generator_impl_function_test.cc
@@ -783,41 +783,6 @@
)");
}
-TEST_F(GlslGeneratorImplTest_Function,
- Emit_Attribute_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
- Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
- Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
- Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
- Func("main", utils::Empty, ty.void_(), {},
- utils::Vector{
- Stage(ast::PipelineStage::kCompute),
- WorkgroupSize("width", "height", "depth"),
- });
-
- GeneratorImpl& gen = Build();
-
- ASSERT_TRUE(gen.Generate()) << gen.error();
- EXPECT_EQ(gen.result(), R"(#version 310 es
-
-#ifndef WGSL_SPEC_CONSTANT_7
-#define WGSL_SPEC_CONSTANT_7 2
-#endif
-const int width = WGSL_SPEC_CONSTANT_7;
-#ifndef WGSL_SPEC_CONSTANT_8
-#define WGSL_SPEC_CONSTANT_8 3
-#endif
-const int height = WGSL_SPEC_CONSTANT_8;
-#ifndef WGSL_SPEC_CONSTANT_9
-#define WGSL_SPEC_CONSTANT_9 4
-#endif
-const int depth = WGSL_SPEC_CONSTANT_9;
-layout(local_size_x = WGSL_SPEC_CONSTANT_7, local_size_y = WGSL_SPEC_CONSTANT_8, local_size_z = WGSL_SPEC_CONSTANT_9) in;
-void main() {
- return;
-}
-)");
-}
-
TEST_F(GlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
Func("my_func", utils::Vector{Param("a", ty.array<f32, 5>())}, ty.void_(),
utils::Vector{
diff --git a/src/tint/writer/glsl/generator_impl_module_constant_test.cc b/src/tint/writer/glsl/generator_impl_module_constant_test.cc
index 26dbf2e..b9da4d5 100644
--- a/src/tint/writer/glsl/generator_impl_module_constant_test.cc
+++ b/src/tint/writer/glsl/generator_impl_module_constant_test.cc
@@ -345,50 +345,5 @@
)");
}
-TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_Override) {
- auto* var = Override("pos", ty.f32(), Expr(3_f), Id(23_a));
-
- GeneratorImpl& gen = Build();
-
- ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
- EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
-#define WGSL_SPEC_CONSTANT_23 3.0f
-#endif
-const float pos = WGSL_SPEC_CONSTANT_23;
-)");
-}
-
-TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_Override_NoConstructor) {
- auto* var = Override("pos", ty.f32(), Id(23_a));
-
- GeneratorImpl& gen = Build();
-
- ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
- EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
-#error spec constant required for constant id 23
-#endif
-const float pos = WGSL_SPEC_CONSTANT_23;
-)");
-}
-
-TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_Override_NoId) {
- auto* a = Override("a", ty.f32(), Expr(3_f), Id(0_a));
- auto* b = Override("b", ty.f32(), Expr(2_f));
-
- GeneratorImpl& gen = Build();
-
- ASSERT_TRUE(gen.EmitOverride(a)) << gen.error();
- ASSERT_TRUE(gen.EmitOverride(b)) << gen.error();
- EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_0
-#define WGSL_SPEC_CONSTANT_0 3.0f
-#endif
-const float a = WGSL_SPEC_CONSTANT_0;
-#ifndef WGSL_SPEC_CONSTANT_1
-#define WGSL_SPEC_CONSTANT_1 2.0f
-#endif
-const float b = WGSL_SPEC_CONSTANT_1;
-)");
-}
-
} // namespace
} // namespace tint::writer::glsl
diff --git a/src/tint/writer/glsl/generator_impl_type_test.cc b/src/tint/writer/glsl/generator_impl_type_test.cc
index 15c1398..c2d9109 100644
--- a/src/tint/writer/glsl/generator_impl_type_test.cc
+++ b/src/tint/writer/glsl/generator_impl_type_test.cc
@@ -212,8 +212,8 @@
TEST_F(GlslGeneratorImplTest_Type, EmitType_Struct_WithOffsetAttributes) {
auto* s = Structure("S", utils::Vector{
- Member("a", ty.i32(), utils::Vector{MemberOffset(0)}),
- Member("b", ty.f32(), utils::Vector{MemberOffset(8)}),
+ Member("a", ty.i32(), utils::Vector{MemberOffset(0_a)}),
+ Member("b", ty.f32(), utils::Vector{MemberOffset(8_a)}),
});
GlobalVar("g", ty.Of(s), ast::StorageClass::kPrivate);
diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc
index b05db3e..802c5a7 100644
--- a/src/tint/writer/hlsl/generator_impl.cc
+++ b/src/tint/writer/hlsl/generator_impl.cc
@@ -2846,7 +2846,12 @@
}
}
},
- [&](const ast::Override* override) { return EmitOverride(override); },
+ [&](const ast::Override*) {
+ // Override is removed with SubstituteOverride
+ TINT_ICE(Writer, diagnostics_)
+ << "Override should have been removed by the substitute_override transform.";
+ return false;
+ },
[&](const ast::Const*) {
return true; // Constants are embedded at their use
},
@@ -4092,36 +4097,6 @@
return true;
}
-bool GeneratorImpl::EmitOverride(const ast::Override* override) {
- auto* sem = builder_.Sem().Get(override);
- auto* type = sem->Type();
-
- auto override_id = sem->OverrideId();
-
- line() << "#ifndef " << kSpecConstantPrefix << override_id.value;
-
- if (override->constructor != nullptr) {
- auto out = line();
- out << "#define " << kSpecConstantPrefix << override_id.value << " ";
- if (!EmitExpression(out, override->constructor)) {
- return false;
- }
- } else {
- line() << "#error spec constant required for constant id " << override_id.value;
- }
- line() << "#endif";
- {
- auto out = line();
- out << "static const ";
- if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
- builder_.Symbols().NameFor(override->symbol))) {
- return false;
- }
- out << " = " << kSpecConstantPrefix << override_id.value << ";";
- }
- return true;
-}
-
template <typename F>
bool GeneratorImpl::CallBuiltinHelper(std::ostream& out,
const ast::CallExpression* call,
diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h
index 78680c8..abbf818 100644
--- a/src/tint/writer/hlsl/generator_impl.h
+++ b/src/tint/writer/hlsl/generator_impl.h
@@ -449,10 +449,6 @@
/// @param let the variable to generate
/// @returns true if the variable was emitted
bool EmitLet(const ast::Let* let);
- /// Handles generating a module-scope 'override' declaration
- /// @param override the 'override' to emit
- /// @returns true if the variable was emitted
- bool EmitOverride(const ast::Override* override);
/// Emits call to a helper vector assignment function for the input assignment
/// statement and vector type. This is used to work around FXC issues where
/// assignments to vectors with dynamic indices cause compilation failures.
diff --git a/src/tint/writer/hlsl/generator_impl_function_test.cc b/src/tint/writer/hlsl/generator_impl_function_test.cc
index bcd1891..5c167fa 100644
--- a/src/tint/writer/hlsl/generator_impl_function_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_function_test.cc
@@ -712,40 +712,6 @@
)");
}
-TEST_F(HlslGeneratorImplTest_Function,
- Emit_Attribute_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
- Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
- Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
- Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
- Func("main", utils::Empty, ty.void_(), utils::Empty,
- utils::Vector{
- Stage(ast::PipelineStage::kCompute),
- WorkgroupSize("width", "height", "depth"),
- });
-
- GeneratorImpl& gen = Build();
-
- ASSERT_TRUE(gen.Generate()) << gen.error();
- EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_7
-#define WGSL_SPEC_CONSTANT_7 2
-#endif
-static const int width = WGSL_SPEC_CONSTANT_7;
-#ifndef WGSL_SPEC_CONSTANT_8
-#define WGSL_SPEC_CONSTANT_8 3
-#endif
-static const int height = WGSL_SPEC_CONSTANT_8;
-#ifndef WGSL_SPEC_CONSTANT_9
-#define WGSL_SPEC_CONSTANT_9 4
-#endif
-static const int depth = WGSL_SPEC_CONSTANT_9;
-
-[numthreads(WGSL_SPEC_CONSTANT_7, WGSL_SPEC_CONSTANT_8, WGSL_SPEC_CONSTANT_9)]
-void main() {
- return;
-}
-)");
-}
-
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
Func("my_func",
utils::Vector{
diff --git a/src/tint/writer/hlsl/generator_impl_module_constant_test.cc b/src/tint/writer/hlsl/generator_impl_module_constant_test.cc
index 58fdbf1..57dc4b8 100644
--- a/src/tint/writer/hlsl/generator_impl_module_constant_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_module_constant_test.cc
@@ -242,50 +242,5 @@
)");
}
-TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override) {
- auto* var = Override("pos", ty.f32(), Expr(3_f), Id(23_a));
-
- GeneratorImpl& gen = Build();
-
- ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
- EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
-#define WGSL_SPEC_CONSTANT_23 3.0f
-#endif
-static const float pos = WGSL_SPEC_CONSTANT_23;
-)");
-}
-
-TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoConstructor) {
- auto* var = Override("pos", ty.f32(), Id(23_a));
-
- GeneratorImpl& gen = Build();
-
- ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
- EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
-#error spec constant required for constant id 23
-#endif
-static const float pos = WGSL_SPEC_CONSTANT_23;
-)");
-}
-
-TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoId) {
- auto* a = Override("a", ty.f32(), Expr(3_f), Id(0_a));
- auto* b = Override("b", ty.f32(), Expr(2_f));
-
- GeneratorImpl& gen = Build();
-
- ASSERT_TRUE(gen.EmitOverride(a)) << gen.error();
- ASSERT_TRUE(gen.EmitOverride(b)) << gen.error();
- EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_0
-#define WGSL_SPEC_CONSTANT_0 3.0f
-#endif
-static const float a = WGSL_SPEC_CONSTANT_0;
-#ifndef WGSL_SPEC_CONSTANT_1
-#define WGSL_SPEC_CONSTANT_1 2.0f
-#endif
-static const float b = WGSL_SPEC_CONSTANT_1;
-)");
-}
-
} // namespace
} // namespace tint::writer::hlsl
diff --git a/src/tint/writer/hlsl/generator_impl_type_test.cc b/src/tint/writer/hlsl/generator_impl_type_test.cc
index 279eb3f..75dac72 100644
--- a/src/tint/writer/hlsl/generator_impl_type_test.cc
+++ b/src/tint/writer/hlsl/generator_impl_type_test.cc
@@ -221,8 +221,8 @@
TEST_F(HlslGeneratorImplTest_Type, EmitType_Struct_WithOffsetAttributes) {
auto* s = Structure("S", utils::Vector{
- Member("a", ty.i32(), utils::Vector{MemberOffset(0)}),
- Member("b", ty.f32(), utils::Vector{MemberOffset(8)}),
+ Member("a", ty.i32(), utils::Vector{MemberOffset(0_a)}),
+ Member("b", ty.f32(), utils::Vector{MemberOffset(8_a)}),
});
GlobalVar("g", ty.Of(s), ast::StorageClass::kPrivate);
diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc
index 70bdb61..4bc68e9 100644
--- a/src/tint/writer/msl/generator_impl.cc
+++ b/src/tint/writer/msl/generator_impl.cc
@@ -271,9 +271,11 @@
[&](const ast::Const*) {
return true; // Constants are embedded at their use
},
- [&](const ast::Override* override) {
- TINT_DEFER(line());
- return EmitOverride(override);
+ [&](const ast::Override*) {
+ // Override is removed with SubstituteOverride
+ TINT_ICE(Writer, diagnostics_)
+ << "Override should have been removed by the substitute_override transform.";
+ return false;
},
[&](const ast::Function* func) {
TINT_DEFER(line());
@@ -3038,22 +3040,6 @@
return true;
}
-bool GeneratorImpl::EmitOverride(const ast::Override* override) {
- auto* global = program_->Sem().Get<sem::GlobalVariable>(override);
- auto* type = global->Type();
-
- auto out = line();
- out << "constant ";
- if (!EmitType(out, type, program_->Symbols().NameFor(override->symbol))) {
- return false;
- }
- out << " " << program_->Symbols().NameFor(override->symbol);
-
- out << " [[function_constant(" << global->OverrideId().value << ")]];";
-
- return true;
-}
-
GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::Type* ty) {
return Switch(
ty,
diff --git a/src/tint/writer/msl/generator_impl.h b/src/tint/writer/msl/generator_impl.h
index 72ee42c..a5948aa 100644
--- a/src/tint/writer/msl/generator_impl.h
+++ b/src/tint/writer/msl/generator_impl.h
@@ -356,10 +356,6 @@
/// @param let the variable to generate
/// @returns true if the variable was emitted
bool EmitLet(const ast::Let* let);
- /// Handles generating a module-scope 'override' declaration
- /// @param override the 'override' to emit
- /// @returns true if the variable was emitted
- bool EmitOverride(const ast::Override* override);
/// Emits the zero value for the given type
/// @param out the output of the expression stream
/// @param type the type to emit the value for
diff --git a/src/tint/writer/msl/generator_impl_module_constant_test.cc b/src/tint/writer/msl/generator_impl_module_constant_test.cc
index 9f665fd..0ff3290 100644
--- a/src/tint/writer/msl/generator_impl_module_constant_test.cc
+++ b/src/tint/writer/msl/generator_impl_module_constant_test.cc
@@ -328,27 +328,5 @@
)");
}
-TEST_F(MslGeneratorImplTest, Emit_Override) {
- auto* var = Override("pos", ty.f32(), Expr(3_f), Id(23_a));
-
- GeneratorImpl& gen = Build();
-
- ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
- EXPECT_EQ(gen.result(), "constant float pos [[function_constant(23)]];\n");
-}
-
-TEST_F(MslGeneratorImplTest, Emit_Override_NoId) {
- auto* var_a = Override("a", ty.f32(), Id(0_a));
- auto* var_b = Override("b", ty.f32());
-
- GeneratorImpl& gen = Build();
-
- ASSERT_TRUE(gen.EmitOverride(var_a)) << gen.error();
- ASSERT_TRUE(gen.EmitOverride(var_b)) << gen.error();
- EXPECT_EQ(gen.result(), R"(constant float a [[function_constant(0)]];
-constant float b [[function_constant(1)]];
-)");
-}
-
} // namespace
} // namespace tint::writer::msl
diff --git a/src/tint/writer/msl/generator_impl_type_test.cc b/src/tint/writer/msl/generator_impl_type_test.cc
index a6901b4..c8f4f6d 100644
--- a/src/tint/writer/msl/generator_impl_type_test.cc
+++ b/src/tint/writer/msl/generator_impl_type_test.cc
@@ -254,8 +254,8 @@
TEST_F(MslGeneratorImplTest, EmitType_Struct_Layout_NonComposites) {
auto* s = Structure(
"S", utils::Vector{
- Member("a", ty.i32(), utils::Vector{MemberSize(32)}),
- Member("b", ty.f32(), utils::Vector{MemberAlign(128_u), MemberSize(128)}),
+ Member("a", ty.i32(), utils::Vector{MemberSize(32_a)}),
+ Member("b", ty.f32(), utils::Vector{MemberAlign(128_u), MemberSize(128_a)}),
Member("c", ty.vec2<f32>()),
Member("d", ty.u32()),
Member("e", ty.vec3<f32>()),
@@ -376,10 +376,11 @@
});
// inner_y: size(516), align(4)
- auto* inner_y = Structure("inner_y", utils::Vector{
- Member("a", ty.i32(), utils::Vector{MemberSize(512)}),
- Member("b", ty.f32()),
- });
+ auto* inner_y =
+ Structure("inner_y", utils::Vector{
+ Member("a", ty.i32(), utils::Vector{MemberSize(512_a)}),
+ Member("b", ty.f32()),
+ });
auto* s = Structure("S", utils::Vector{
Member("a", ty.i32()),
@@ -595,7 +596,7 @@
TEST_F(MslGeneratorImplTest, AttemptTintPadSymbolCollision) {
auto* s = Structure("S", utils::Vector{
// uses symbols tint_pad_[0..9] and tint_pad_[20..35]
- Member("tint_pad_2", ty.i32(), utils::Vector{MemberSize(32)}),
+ Member("tint_pad_2", ty.i32(), utils::Vector{MemberSize(32_a)}),
Member("tint_pad_20", ty.f32(),
utils::Vector{MemberAlign(128_u), MemberSize(128_u)}),
Member("tint_pad_33", ty.vec2<f32>()),
diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc
index 3ee241d..f51d808 100644
--- a/src/tint/writer/spirv/builder.cc
+++ b/src/tint/writer/spirv/builder.cc
@@ -506,60 +506,13 @@
} else if (func->PipelineStage() == ast::PipelineStage::kCompute) {
auto& wgsize = func_sem->WorkgroupSize();
- // Check if the workgroup_size uses pipeline-overridable constants.
- if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
- wgsize[2].overridable_const) {
- if (has_overridable_workgroup_size_) {
- // Only one stage can have a pipeline-overridable workgroup size.
- // TODO(crbug.com/tint/810): Use LocalSizeId to handle this scenario.
- TINT_ICE(Writer, builder_.Diagnostics())
- << "multiple stages using pipeline-overridable workgroup sizes";
- }
- has_overridable_workgroup_size_ = true;
-
- auto* vec3_u32 = builder_.create<sem::Vector>(builder_.create<sem::U32>(), 3u);
- uint32_t vec3_u32_type_id = GenerateTypeIfNeeded(vec3_u32);
- if (vec3_u32_type_id == 0) {
- return 0;
- }
-
- OperandList wgsize_ops;
- auto wgsize_result = result_op();
- wgsize_ops.push_back(Operand(vec3_u32_type_id));
- wgsize_ops.push_back(wgsize_result);
-
- // Generate OpConstant instructions for each dimension.
- for (size_t i = 0; i < 3; i++) {
- auto constant = ScalarConstant::U32(wgsize[i].value);
- if (wgsize[i].overridable_const) {
- // Make the constant specializable.
- auto* sem_const =
- builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const);
- if (!sem_const->Declaration()->Is<ast::Override>()) {
- TINT_ICE(Writer, builder_.Diagnostics())
- << "expected a pipeline-overridable constant";
- }
- constant.is_spec_op = true;
- constant.constant_id = sem_const->OverrideId().value;
- }
-
- auto result = GenerateConstantIfNeeded(constant);
- wgsize_ops.push_back(Operand(result));
- }
-
- // Generate the WorkgroupSize builtin.
- push_type(spv::Op::OpSpecConstantComposite, wgsize_ops);
- push_annot(spv::Op::OpDecorate, {wgsize_result, U32Operand(SpvDecorationBuiltIn),
- U32Operand(SpvBuiltInWorkgroupSize)});
- } else {
- // Not overridable, so just use OpExecutionMode LocalSize.
- uint32_t x = wgsize[0].value;
- uint32_t y = wgsize[1].value;
- uint32_t z = wgsize[2].value;
- push_execution_mode(spv::Op::OpExecutionMode,
- {Operand(id), U32Operand(SpvExecutionModeLocalSize), Operand(x),
- Operand(y), Operand(z)});
- }
+ // SubstituteOverride replaced all overrides with constants.
+ uint32_t x = wgsize[0].value;
+ uint32_t y = wgsize[1].value;
+ uint32_t z = wgsize[2].value;
+ push_execution_mode(spv::Op::OpExecutionMode,
+ {Operand(id), U32Operand(SpvExecutionModeLocalSize), Operand(x),
+ Operand(y), Operand(z)});
}
for (auto builtin : func_sem->TransitivelyReferencedBuiltinVariables()) {
@@ -585,7 +538,7 @@
[&](const ast::BitcastExpression* b) { return GenerateBitcastExpression(b); },
[&](const ast::CallExpression* c) { return GenerateCallExpression(c); },
[&](const ast::IdentifierExpression* i) { return GenerateIdentifierExpression(i); },
- [&](const ast::LiteralExpression* l) { return GenerateLiteralIfNeeded(nullptr, l); },
+ [&](const ast::LiteralExpression* l) { return GenerateLiteralIfNeeded(l); },
[&](const ast::MemberAccessorExpression* m) { return GenerateAccessorExpression(m); },
[&](const ast::UnaryOpExpression* u) { return GenerateUnaryOpExpression(u); },
[&](Default) {
@@ -778,46 +731,6 @@
}
}
- if (auto* override = v->As<ast::Override>(); override && !override->constructor) {
- // SPIR-V requires specialization constants to have initializers.
- init_id = Switch(
- type, //
- [&](const sem::F32*) {
- ast::FloatLiteralExpression l(ProgramID{}, ast::NodeID{}, Source{}, 0,
- ast::FloatLiteralExpression::Suffix::kF);
- return GenerateLiteralIfNeeded(override, &l);
- },
- [&](const sem::U32*) {
- ast::IntLiteralExpression l(ProgramID{}, ast::NodeID{}, Source{}, 0,
- ast::IntLiteralExpression::Suffix::kU);
- return GenerateLiteralIfNeeded(override, &l);
- },
- [&](const sem::I32*) {
- ast::IntLiteralExpression l(ProgramID{}, ast::NodeID{}, Source{}, 0,
- ast::IntLiteralExpression::Suffix::kI);
- return GenerateLiteralIfNeeded(override, &l);
- },
- [&](const sem::Bool*) {
- ast::BoolLiteralExpression l(ProgramID{}, ast::NodeID{}, Source{}, false);
- return GenerateLiteralIfNeeded(override, &l);
- },
- [&](Default) {
- error_ = "invalid type for pipeline constant ID, must be scalar";
- return 0;
- });
- if (init_id == 0) {
- return 0;
- }
- }
-
- if (v->Is<ast::Override>()) {
- push_debug(spv::Op::OpName,
- {Operand(init_id), Operand(builder_.Symbols().NameFor(v->symbol))});
-
- RegisterVariable(sem, init_id);
- return true;
- }
-
auto result = result_op();
auto var_id = std::get<uint32_t>(result);
@@ -1293,15 +1206,9 @@
uint32_t Builder::GenerateConstructorExpression(const ast::Variable* var,
const ast::Expression* expr) {
- if (Is<ast::Override>(var)) {
- if (auto* literal = expr->As<ast::LiteralExpression>()) {
- return GenerateLiteralIfNeeded(var, literal);
- }
- } else {
- if (auto* sem = builder_.Sem().Get(expr)) {
- if (auto constant = sem->ConstantValue()) {
- return GenerateConstantIfNeeded(constant);
- }
+ if (auto* sem = builder_.Sem().Get(expr)) {
+ if (auto constant = sem->ConstantValue()) {
+ return GenerateConstantIfNeeded(constant);
}
}
if (auto* call = builder_.Sem().Get<sem::Call>(expr)) {
@@ -1346,24 +1253,6 @@
// Generate the zero initializer if there are no values provided.
if (args.IsEmpty()) {
- if (global_var && global_var->Declaration()->Is<ast::Override>()) {
- auto constant_id = global_var->OverrideId().value;
- if (result_type->Is<sem::I32>()) {
- return GenerateConstantIfNeeded(ScalarConstant::I32(0).AsSpecOp(constant_id));
- }
- if (result_type->Is<sem::U32>()) {
- return GenerateConstantIfNeeded(ScalarConstant::U32(0).AsSpecOp(constant_id));
- }
- if (result_type->Is<sem::F32>()) {
- return GenerateConstantIfNeeded(ScalarConstant::F32(0).AsSpecOp(constant_id));
- }
- if (result_type->Is<sem::F16>()) {
- return GenerateConstantIfNeeded(ScalarConstant::F16(0).AsSpecOp(constant_id));
- }
- if (result_type->Is<sem::Bool>()) {
- return GenerateConstantIfNeeded(ScalarConstant::Bool(false).AsSpecOp(constant_id));
- }
- }
return GenerateConstantNullIfNeeded(result_type->UnwrapRef());
}
@@ -1679,16 +1568,8 @@
return result_id;
}
-uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
- const ast::LiteralExpression* lit) {
+uint32_t Builder::GenerateLiteralIfNeeded(const ast::LiteralExpression* lit) {
ScalarConstant constant;
-
- auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
- if (global && global->Declaration()->Is<ast::Override>()) {
- constant.is_spec_op = true;
- constant.constant_id = global->OverrideId().value;
- }
-
Switch(
lit,
[&](const ast::BoolLiteralExpression* l) {
@@ -1837,42 +1718,30 @@
auto result = result_op();
auto result_id = std::get<uint32_t>(result);
- if (constant.is_spec_op) {
- push_annot(spv::Op::OpDecorate, {Operand(result_id), U32Operand(SpvDecorationSpecId),
- Operand(constant.constant_id)});
- }
-
switch (constant.kind) {
case ScalarConstant::Kind::kU32: {
- push_type(constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
- {Operand(type_id), result, Operand(constant.value.u32)});
+ push_type(spv::Op::OpConstant, {Operand(type_id), result, Operand(constant.value.u32)});
break;
}
case ScalarConstant::Kind::kI32: {
- push_type(constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+ push_type(spv::Op::OpConstant,
{Operand(type_id), result, U32Operand(constant.value.i32)});
break;
}
case ScalarConstant::Kind::kF32: {
- push_type(constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
- {Operand(type_id), result, Operand(constant.value.f32)});
+ push_type(spv::Op::OpConstant, {Operand(type_id), result, Operand(constant.value.f32)});
break;
}
case ScalarConstant::Kind::kF16: {
- push_type(
- constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
- {Operand(type_id), result, U32Operand(constant.value.f16.bits_representation)});
+ push_type(spv::Op::OpConstant, {Operand(type_id), result,
+ U32Operand(constant.value.f16.bits_representation)});
break;
}
case ScalarConstant::Kind::kBool: {
if (constant.value.b) {
- push_type(
- constant.is_spec_op ? spv::Op::OpSpecConstantTrue : spv::Op::OpConstantTrue,
- {Operand(type_id), result});
+ push_type(spv::Op::OpConstantTrue, {Operand(type_id), result});
} else {
- push_type(
- constant.is_spec_op ? spv::Op::OpSpecConstantFalse : spv::Op::OpConstantFalse,
- {Operand(type_id), result});
+ push_type(spv::Op::OpConstantFalse, {Operand(type_id), result});
}
break;
}
diff --git a/src/tint/writer/spirv/builder.h b/src/tint/writer/spirv/builder.h
index a0eff10..ba88443 100644
--- a/src/tint/writer/spirv/builder.h
+++ b/src/tint/writer/spirv/builder.h
@@ -333,10 +333,9 @@
/// @returns the ID of the expression or 0 on failure.
uint32_t GenerateConstructorExpression(const ast::Variable* var, const ast::Expression* expr);
/// Generates a literal constant if needed
- /// @param var the variable generated for, nullptr if no variable associated.
/// @param lit the literal to generate
/// @returns the ID on success or 0 on failure
- uint32_t GenerateLiteralIfNeeded(const ast::Variable* var, const ast::LiteralExpression* lit);
+ uint32_t GenerateLiteralIfNeeded(const ast::LiteralExpression* lit);
/// Generates a binary expression
/// @param expr the expression to generate
/// @returns the expression ID on success or 0 otherwise
@@ -625,7 +624,6 @@
std::vector<uint32_t> merge_stack_;
std::vector<uint32_t> continue_stack_;
std::unordered_set<uint32_t> capability_set_;
- bool has_overridable_workgroup_size_ = false;
bool zero_initialize_workgroup_memory_ = false;
struct ContinuingInfo {
diff --git a/src/tint/writer/spirv/builder_function_attribute_test.cc b/src/tint/writer/spirv/builder_function_attribute_test.cc
index 8825c79..9e6c338 100644
--- a/src/tint/writer/spirv/builder_function_attribute_test.cc
+++ b/src/tint/writer/spirv/builder_function_attribute_test.cc
@@ -149,63 +149,6 @@
)");
}
-TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_OverridableConst) {
- Override("width", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
- Override("height", ty.i32(), Construct(ty.i32(), 3_i), Id(8_u));
- Override("depth", ty.i32(), Construct(ty.i32(), 4_i), Id(9_u));
- auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
- utils::Vector{
- WorkgroupSize("width", "height", "depth"),
- Stage(ast::PipelineStage::kCompute),
- });
-
- spirv::Builder& b = Build();
-
- ASSERT_TRUE(b.GenerateExecutionModes(func, 3)) << b.error();
- EXPECT_EQ(DumpInstructions(b.execution_modes()), "");
- EXPECT_EQ(DumpInstructions(b.types()),
- R"(%2 = OpTypeInt 32 0
-%1 = OpTypeVector %2 3
-%4 = OpSpecConstant %2 2
-%5 = OpSpecConstant %2 3
-%6 = OpSpecConstant %2 4
-%3 = OpSpecConstantComposite %1 %4 %5 %6
-)");
- EXPECT_EQ(DumpInstructions(b.annots()),
- R"(OpDecorate %4 SpecId 7
-OpDecorate %5 SpecId 8
-OpDecorate %6 SpecId 9
-OpDecorate %3 BuiltIn WorkgroupSize
-)");
-}
-
-TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_LiteralAndConst) {
- Override("height", ty.i32(), Construct(ty.i32(), 2_i), Id(7_u));
- GlobalConst("depth", ty.i32(), Construct(ty.i32(), 3_i));
- auto* func = Func("main", utils::Empty, ty.void_(), utils::Empty,
- utils::Vector{
- WorkgroupSize(4_i, "height", "depth"),
- Stage(ast::PipelineStage::kCompute),
- });
-
- spirv::Builder& b = Build();
-
- ASSERT_TRUE(b.GenerateExecutionModes(func, 3)) << b.error();
- EXPECT_EQ(DumpInstructions(b.execution_modes()), "");
- EXPECT_EQ(DumpInstructions(b.types()),
- R"(%2 = OpTypeInt 32 0
-%1 = OpTypeVector %2 3
-%4 = OpConstant %2 4
-%5 = OpSpecConstant %2 2
-%6 = OpConstant %2 3
-%3 = OpSpecConstantComposite %1 %4 %5 %6
-)");
- EXPECT_EQ(DumpInstructions(b.annots()),
- R"(OpDecorate %5 SpecId 7
-OpDecorate %3 BuiltIn WorkgroupSize
-)");
-}
-
TEST_F(BuilderTest, Decoration_ExecutionMode_MultipleFragment) {
auto* func1 = Func("main1", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
diff --git a/src/tint/writer/spirv/builder_global_variable_test.cc b/src/tint/writer/spirv/builder_global_variable_test.cc
index dfb38e5..2c53023 100644
--- a/src/tint/writer/spirv/builder_global_variable_test.cc
+++ b/src/tint/writer/spirv/builder_global_variable_test.cc
@@ -249,146 +249,6 @@
)");
}
-TEST_F(BuilderTest, GlobalVar_Override_Bool) {
- auto* v = Override("var", ty.bool_(), Expr(true), Id(1200_a));
-
- spirv::Builder& b = Build();
-
- EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
- EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var"
-)");
- EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 1200
-)");
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
-%2 = OpSpecConstantTrue %1
-)");
-}
-
-TEST_F(BuilderTest, GlobalVar_Override_Bool_ZeroValue) {
- auto* v = Override("var", ty.bool_(), Construct<bool>(), Id(1200_a));
-
- spirv::Builder& b = Build();
-
- EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
- EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var"
-)");
- EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 1200
-)");
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
-%2 = OpSpecConstantFalse %1
-)");
-}
-
-TEST_F(BuilderTest, GlobalVar_Override_Bool_NoConstructor) {
- auto* v = Override("var", ty.bool_(), Id(1200_a));
-
- spirv::Builder& b = Build();
-
- EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
- EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var"
-)");
- EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 1200
-)");
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
-%2 = OpSpecConstantFalse %1
-)");
-}
-
-TEST_F(BuilderTest, GlobalVar_Override_Scalar) {
- auto* v = Override("var", ty.f32(), Expr(2_f), Id(0_a));
-
- spirv::Builder& b = Build();
-
- EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
- EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var"
-)");
- EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0
-)");
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
-%2 = OpSpecConstant %1 2
-)");
-}
-
-TEST_F(BuilderTest, GlobalVar_Override_Scalar_ZeroValue) {
- auto* v = Override("var", ty.f32(), Construct<f32>(), Id(0_a));
-
- spirv::Builder& b = Build();
-
- EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
- EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var"
-)");
- EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0
-)");
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
-%2 = OpSpecConstant %1 0
-)");
-}
-
-TEST_F(BuilderTest, GlobalVar_Override_Scalar_F32_NoConstructor) {
- auto* v = Override("var", ty.f32(), Id(0_a));
-
- spirv::Builder& b = Build();
-
- EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
- EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var"
-)");
- EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0
-)");
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
-%2 = OpSpecConstant %1 0
-)");
-}
-
-TEST_F(BuilderTest, GlobalVar_Override_Scalar_I32_NoConstructor) {
- auto* v = Override("var", ty.i32(), Id(0_a));
-
- spirv::Builder& b = Build();
-
- EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
- EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var"
-)");
- EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0
-)");
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
-%2 = OpSpecConstant %1 0
-)");
-}
-
-TEST_F(BuilderTest, GlobalVar_Override_Scalar_U32_NoConstructor) {
- auto* v = Override("var", ty.u32(), Id(0_a));
-
- spirv::Builder& b = Build();
-
- EXPECT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
- EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "var"
-)");
- EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0
-)");
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
-%2 = OpSpecConstant %1 0
-)");
-}
-
-TEST_F(BuilderTest, GlobalVar_Override_NoId) {
- auto* var_a = Override("a", ty.bool_(), Expr(true), Id(0_a));
- auto* var_b = Override("b", ty.bool_(), Expr(false));
-
- spirv::Builder& b = Build();
-
- EXPECT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error();
- EXPECT_TRUE(b.GenerateGlobalVariable(var_b)) << b.error();
- EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %2 "a"
-OpName %3 "b"
-)");
- EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0
-OpDecorate %3 SpecId 1
-)");
- EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
-%2 = OpSpecConstantTrue %1
-%3 = OpSpecConstantFalse %1
-)");
-}
-
struct BuiltinData {
ast::BuiltinValue builtin;
ast::StorageClass storage;
diff --git a/src/tint/writer/spirv/builder_literal_test.cc b/src/tint/writer/spirv/builder_literal_test.cc
index 374c80b..0d53237 100644
--- a/src/tint/writer/spirv/builder_literal_test.cc
+++ b/src/tint/writer/spirv/builder_literal_test.cc
@@ -27,7 +27,7 @@
spirv::Builder& b = Build();
- auto id = b.GenerateLiteralIfNeeded(nullptr, b_true);
+ auto id = b.GenerateLiteralIfNeeded(b_true);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -42,7 +42,7 @@
spirv::Builder& b = Build();
- auto id = b.GenerateLiteralIfNeeded(nullptr, b_false);
+ auto id = b.GenerateLiteralIfNeeded(b_false);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -58,11 +58,11 @@
spirv::Builder& b = Build();
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, b_true), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(b_true), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, b_false), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(b_false), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, b_true), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(b_true), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
@@ -76,7 +76,7 @@
WrapInFunction(i);
spirv::Builder& b = Build();
- auto id = b.GenerateLiteralIfNeeded(nullptr, i);
+ auto id = b.GenerateLiteralIfNeeded(i);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -92,8 +92,8 @@
spirv::Builder& b = Build();
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i1), 0u);
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i2), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(i1), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(i2), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
@@ -107,7 +107,7 @@
spirv::Builder& b = Build();
- auto id = b.GenerateLiteralIfNeeded(nullptr, i);
+ auto id = b.GenerateLiteralIfNeeded(i);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -123,8 +123,8 @@
spirv::Builder& b = Build();
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i1), 0u);
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i2), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(i1), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(i2), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
@@ -138,7 +138,7 @@
spirv::Builder& b = Build();
- auto id = b.GenerateLiteralIfNeeded(nullptr, i);
+ auto id = b.GenerateLiteralIfNeeded(i);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -154,8 +154,8 @@
spirv::Builder& b = Build();
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i1), 0u);
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i2), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(i1), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(i2), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
@@ -171,7 +171,7 @@
spirv::Builder& b = Build();
- auto id = b.GenerateLiteralIfNeeded(nullptr, i);
+ auto id = b.GenerateLiteralIfNeeded(i);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -189,8 +189,8 @@
spirv::Builder& b = Build();
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i1), 0u);
- ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, i2), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(i1), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(i2), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 16
diff --git a/src/tint/writer/spirv/scalar_constant.h b/src/tint/writer/spirv/scalar_constant.h
index 0629d0c..dd4dfd0 100644
--- a/src/tint/writer/spirv/scalar_constant.h
+++ b/src/tint/writer/spirv/scalar_constant.h
@@ -111,8 +111,7 @@
/// @param rhs the ScalarConstant to compare against
/// @returns true if this ScalarConstant is equal to `rhs`
inline bool operator==(const ScalarConstant& rhs) const {
- return value.u64 == rhs.value.u64 && kind == rhs.kind && is_spec_op == rhs.is_spec_op &&
- constant_id == rhs.constant_id;
+ return value.u64 == rhs.value.u64 && kind == rhs.kind;
}
/// Inequality operator
@@ -120,24 +119,10 @@
/// @returns true if this ScalarConstant is not equal to `rhs`
inline bool operator!=(const ScalarConstant& rhs) const { return !(*this == rhs); }
- /// @returns this ScalarConstant as a specialization op with the given
- /// specialization constant identifier
- /// @param id the constant identifier
- ScalarConstant AsSpecOp(uint32_t id) const {
- auto ret = *this;
- ret.is_spec_op = true;
- ret.constant_id = id;
- return ret;
- }
-
/// The constant value
Value value;
/// The constant value kind
Kind kind = Kind::kBool;
- /// True if the constant is a specialization op
- bool is_spec_op = false;
- /// The identifier if a specialization op
- uint32_t constant_id = 0;
};
} // namespace tint::writer::spirv
diff --git a/src/tint/writer/spirv/scalar_constant_test.cc b/src/tint/writer/spirv/scalar_constant_test.cc
index b00f82a..d38a050 100644
--- a/src/tint/writer/spirv/scalar_constant_test.cc
+++ b/src/tint/writer/spirv/scalar_constant_test.cc
@@ -34,16 +34,6 @@
EXPECT_NE(a, b);
b.value.b = true;
EXPECT_EQ(a, b);
-
- a.is_spec_op = true;
- EXPECT_NE(a, b);
- b.is_spec_op = true;
- EXPECT_EQ(a, b);
-
- a.constant_id = 3;
- EXPECT_NE(a, b);
- b.constant_id = 3;
- EXPECT_EQ(a, b);
}
TEST_F(SpirvScalarConstantTest, U32) {
diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc
index b2fba29..83a96d9 100644
--- a/src/tint/writer/wgsl/generator_impl.cc
+++ b/src/tint/writer/wgsl/generator_impl.cc
@@ -788,7 +788,11 @@
return true;
},
[&](const ast::StructMemberSizeAttribute* size) {
- out << "size(" << size->size << ")";
+ out << "size(";
+ if (!EmitExpression(out, size->expr)) {
+ return false;
+ }
+ out << ")";
return true;
},
[&](const ast::StructMemberAlignAttribute* align) {
diff --git a/src/tint/writer/wgsl/generator_impl_type_test.cc b/src/tint/writer/wgsl/generator_impl_type_test.cc
index ef90579..f6f39d4 100644
--- a/src/tint/writer/wgsl/generator_impl_type_test.cc
+++ b/src/tint/writer/wgsl/generator_impl_type_test.cc
@@ -178,8 +178,8 @@
TEST_F(WgslGeneratorImplTest, EmitType_StructOffsetDecl) {
auto* s = Structure("S", utils::Vector{
- Member("a", ty.i32(), utils::Vector{MemberOffset(8)}),
- Member("b", ty.f32(), utils::Vector{MemberOffset(16)}),
+ Member("a", ty.i32(), utils::Vector{MemberOffset(8_a)}),
+ Member("b", ty.f32(), utils::Vector{MemberOffset(16_a)}),
});
GeneratorImpl& gen = Build();
@@ -199,8 +199,8 @@
TEST_F(WgslGeneratorImplTest, EmitType_StructOffsetDecl_WithSymbolCollisions) {
auto* s =
Structure("S", utils::Vector{
- Member("tint_0_padding", ty.i32(), utils::Vector{MemberOffset(8)}),
- Member("tint_2_padding", ty.f32(), utils::Vector{MemberOffset(16)}),
+ Member("tint_0_padding", ty.i32(), utils::Vector{MemberOffset(8_a)}),
+ Member("tint_2_padding", ty.f32(), utils::Vector{MemberOffset(16_a)}),
});
GeneratorImpl& gen = Build();
@@ -237,8 +237,8 @@
TEST_F(WgslGeneratorImplTest, EmitType_StructSizeDecl) {
auto* s = Structure("S", utils::Vector{
- Member("a", ty.i32(), utils::Vector{MemberSize(16)}),
- Member("b", ty.f32(), utils::Vector{MemberSize(32)}),
+ Member("a", ty.i32(), utils::Vector{MemberSize(16_a)}),
+ Member("b", ty.f32(), utils::Vector{MemberSize(32_a)}),
});
GeneratorImpl& gen = Build();