Import Tint changes from Dawn
Changes:
- 7ab964aa333f50c5f0d7b4c60d7f17637825e247 Tint: consider subgroup_size used in compute stage as uni... by Jiang <zhaoming.jiang@intel.com>
- d87ca7ef9f5d4f0402a17c60c570a05d1daa191c [tint] Mark operator bool() as explicit by Ben Clayton <bclayton@google.com>
- 01ef49d3a7c149854233c3631383efad82d675f6 [tint][ir] Add error handling to the binary decoder by Ben Clayton <bclayton@google.com>
- 41f06ff13762dcb2c1713fd88fb925a61a70e7c8 Tint: parse @input_attachment_index by Le Hoang Quyen <lehoangquyen@chromium.org>
- 98ee9df998a887c881816a0417b008081f9bef88 [tint][ir] Fix unit test expected output by Ben Clayton <bclayton@google.com>
- 5fc8c9eef18f94f202ebd13822d7dccffe301d24 [ir] Add FunctionParam::SetType() by James Price <jrprice@google.com>
- 06f9dbc3c133c655ecccf5dbfdca8a88cb068e17 [msl] Support handle types in ModuleScopeVars by James Price <jrprice@google.com>
- 12d5c4dac6f33a29021f722bdb84969aa2f9d1bc Tint: Resolve @input_attachment_index. by Le Hoang Quyen <lehoangquyen@chromium.org>
GitOrigin-RevId: 7ab964aa333f50c5f0d7b4c60d7f17637825e247
Change-Id: I7cb60dda1bf465a6c6905fededef29efb1a24462
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/189664
Reviewed-by: dan sinclair <dsinclair@google.com>
Commit-Queue: dan sinclair <dsinclair@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/tint/lang/core/ir/binary/decode.cc b/src/tint/lang/core/ir/binary/decode.cc
index 1960c60..ada52f9 100644
--- a/src/tint/lang/core/ir/binary/decode.cc
+++ b/src/tint/lang/core/ir/binary/decode.cc
@@ -27,18 +27,29 @@
#include "src/tint/lang/core/ir/binary/decode.h"
+#include <cmath>
+#include <cstdint>
+#include <string>
#include <utility>
#include "src/tint/lang/core/ir/builder.h"
+#include "src/tint/lang/core/ir/control_instruction.h"
#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/type/depth_multisampled_texture.h"
#include "src/tint/lang/core/type/depth_texture.h"
#include "src/tint/lang/core/type/external_texture.h"
+#include "src/tint/lang/core/type/invalid.h"
#include "src/tint/lang/core/type/multisampled_texture.h"
#include "src/tint/lang/core/type/sampled_texture.h"
#include "src/tint/lang/core/type/storage_texture.h"
+#include "src/tint/lang/core/type/vector.h"
+#include "src/tint/utils/containers/hashset.h"
#include "src/tint/utils/containers/transform.h"
+#include "src/tint/utils/diagnostic/diagnostic.h"
#include "src/tint/utils/macros/compiler.h"
+#include "src/tint/utils/result/result.h"
+#include "src/tint/utils/text/string.h"
+#include "src/tint/utils/text/text_style.h"
TINT_BEGIN_DISABLE_PROTOBUF_WARNINGS();
#include "src/tint/lang/core/ir/binary/ir.pb.h"
@@ -66,6 +77,9 @@
Vector<ir::BreakIf*, 32> break_ifs_{};
Vector<ir::Continue*, 32> continues_{};
+ diag::List diags_{};
+ Hashset<std::string, 4> struct_names_{};
+
Result<Module> Decode() {
{
const size_t n = static_cast<size_t>(mod_in_.types().size());
@@ -85,7 +99,7 @@
const size_t n = static_cast<size_t>(mod_in_.blocks().size());
blocks_.Reserve(n);
for (size_t i = 0; i < n; i++) {
- auto id = static_cast<uint32_t>(i + 1);
+ auto id = static_cast<uint32_t>(i);
if (id == mod_in_.root_block()) {
blocks_.Push(mod_out_.root_block);
} else {
@@ -115,28 +129,93 @@
PopulateBlock(blocks_[i], mod_in_.blocks()[static_cast<int>(i)]);
}
- for (auto* exit : exit_ifs_) {
- InferControlInstruction(exit, &ExitIf::SetIf);
- }
- for (auto* exit : exit_switches_) {
- InferControlInstruction(exit, &ExitSwitch::SetSwitch);
- }
- for (auto* exit : exit_loops_) {
- InferControlInstruction(exit, &ExitLoop::SetLoop);
- }
- for (auto* break_ifs : break_ifs_) {
- InferControlInstruction(break_ifs, &BreakIf::SetLoop);
- }
- for (auto* next_iters : next_iterations_) {
- InferControlInstruction(next_iters, &NextIteration::SetLoop);
- }
- for (auto* cont : continues_) {
- InferControlInstruction(cont, &Continue::SetLoop);
+ if (diags_.ContainsErrors()) {
+ // Note: Its not safe to call InferControlInstruction() with a broken IR.
+ return Failure{std::move(diags_)};
}
+ if (CheckBlocks()) {
+ for (auto* exit : exit_ifs_) {
+ InferControlInstruction(exit, &ExitIf::SetIf);
+ }
+ for (auto* exit : exit_switches_) {
+ InferControlInstruction(exit, &ExitSwitch::SetSwitch);
+ }
+ for (auto* exit : exit_loops_) {
+ InferControlInstruction(exit, &ExitLoop::SetLoop);
+ }
+ for (auto* break_ifs : break_ifs_) {
+ InferControlInstruction(break_ifs, &BreakIf::SetLoop);
+ }
+ for (auto* next_iters : next_iterations_) {
+ InferControlInstruction(next_iters, &NextIteration::SetLoop);
+ }
+ for (auto* cont : continues_) {
+ InferControlInstruction(cont, &Continue::SetLoop);
+ }
+ }
+
+ if (diags_.ContainsErrors()) {
+ return Failure{std::move(diags_)};
+ }
return std::move(mod_out_);
}
+ /// Adds a new error to the diagnostics and returns a reference to it
+ diag::Diagnostic& Error() { return diags_.AddError(Source{}); }
+
+ /// Errors if @p number is not finite.
+ /// @returns @p number if finite, otherwise 0.
+ template <typename T>
+ Number<T> CheckFinite(Number<T> number) {
+ if (TINT_UNLIKELY(!std::isfinite(number.value))) {
+ Error() << "value must be finite";
+ return Number<T>{};
+ }
+ return number;
+ }
+
+ /// @returns true if all blocks are reachable, acyclic nesting depth is less than or equal to
+ /// kMaxBlockDepth.
+ bool CheckBlocks() {
+ const size_t kMaxBlockDepth = 128;
+ Vector<std::pair<const ir::Block*, size_t>, 32> pending;
+ pending.Push(std::make_pair(mod_out_.root_block, 0));
+ for (auto& fn : mod_out_.functions) {
+ pending.Push(std::make_pair(fn->Block(), 0));
+ }
+ Hashset<const ir::Block*, 32> seen;
+ while (!pending.IsEmpty()) {
+ const auto block_depth = pending.Pop();
+ const auto* block = block_depth.first;
+ const size_t depth = block_depth.second;
+ if (!seen.Add(block)) {
+ Error() << "cyclic nesting of blocks";
+ return false;
+ }
+ if (depth > kMaxBlockDepth) {
+ Error() << "block nesting exceeds " << kMaxBlockDepth;
+ return false;
+ }
+ for (auto* inst = block->Instructions(); inst; inst = inst->next) {
+ if (auto* ctrl = inst->As<ir::ControlInstruction>()) {
+ ctrl->ForeachBlock([&](const ir::Block* child) {
+ pending.Push(std::make_pair(child, depth + 1));
+ });
+ }
+ }
+ }
+
+ for (auto* block : blocks_) {
+ if (!seen.Contains(block)) {
+ Error() << "unreachable block";
+ return false;
+ }
+ }
+
+ return true;
+ }
+
template <typename EXIT, typename CTRL_INST>
void InferControlInstruction(EXIT* exit, void (EXIT::*set)(CTRL_INST*)) {
for (auto* block = exit->Block(); block;) {
@@ -174,7 +253,10 @@
Vector<FunctionParam*, 8> params_out;
for (auto param_in : fn_in.parameters()) {
- params_out.Push(ValueAs<ir::FunctionParam>(param_in));
+ auto* param_out = ValueAs<FunctionParam>(param_in);
+ if (TINT_LIKELY(param_out)) {
+ params_out.Push(param_out);
+ }
}
if (fn_in.has_return_location()) {
fn_out->SetReturnLocation(Location(fn_in.return_location()));
@@ -189,7 +271,13 @@
fn_out->SetBlock(Block(fn_in.block()));
}
- ir::Function* Function(uint32_t id) { return id > 0 ? mod_out_.functions[id - 1] : nullptr; }
+ ir::Function* Function(uint32_t id) {
+ if (TINT_UNLIKELY(id >= mod_out_.functions.Length())) {
+ Error() << "function id " << id << " out of range";
+ return nullptr;
+ }
+ return mod_out_.functions[id];
+ }
Function::PipelineStage PipelineStage(pb::PipelineStage stage) {
switch (stage) {
@@ -215,28 +303,37 @@
}
void PopulateBlock(ir::Block* block_out, const pb::Block& block_in) {
- if (block_in.is_multi_in()) {
+ if (auto* mib = block_out->As<ir::MultiInBlock>()) {
Vector<ir::BlockParam*, 8> params;
- for (auto param : block_in.parameters()) {
- params.Push(ValueAs<BlockParam>(param));
+ for (auto param_in : block_in.parameters()) {
+ auto* param_out = ValueAs<BlockParam>(param_in);
+ if (TINT_LIKELY(param_out)) {
+ params.Push(param_out);
+ }
}
- block_out->As<ir::MultiInBlock>()->SetParams(std::move(params));
+ mib->SetParams(std::move(params));
}
for (auto& inst : block_in.instructions()) {
block_out->Append(Instruction(inst));
}
}
- ir::Block* Block(uint32_t id) { return id > 0 ? blocks_[id - 1] : nullptr; }
+ ir::Block* Block(uint32_t id) {
+ if (TINT_UNLIKELY(id >= blocks_.Length())) {
+ Error() << "block id " << id << " out of range";
+ return b.Block();
+ }
+ return blocks_[id];
+ }
template <typename T>
T* BlockAs(uint32_t id) {
auto* block = Block(id);
- if (auto cast = block->As<T>(); TINT_LIKELY(cast)) {
+ if (auto cast = As<T>(block); TINT_LIKELY(cast)) {
return cast;
}
- TINT_ICE() << "block " << id << " is " << (block ? block->TypeInfo().name : "<null>")
- << " expected " << TypeInfo::Of<T>().name;
+ Error() << "block " << id << " is " << (block ? block->TypeInfo().name : "<null>")
+ << " expected " << TypeInfo::Of<T>().name;
return nullptr;
}
@@ -330,6 +427,11 @@
case pb::Instruction::KindCase::KIND_NOT_SET:
break;
}
+ if (!inst_out) {
+ Error() << "invalid Instruction.kind: " << std::to_string(inst_in.kind_case());
+ return b.Let(mod_out_.Types().invalid());
+ }
+
TINT_ASSERT(inst_out);
Vector<ir::Value*, 4> operands;
@@ -345,8 +447,15 @@
inst_out->SetResults(std::move(results));
if (inst_in.has_break_if()) {
- static_cast<BreakIf*>(inst_out)->SetNumNextIterValues(
- inst_in.break_if().num_next_iter_values());
+ auto num_next_iter_values = inst_in.break_if().num_next_iter_values();
+ bool is_valid =
+ inst_out->Operands().Length() >= num_next_iter_values + BreakIf::kArgsOperandOffset;
+ if (TINT_LIKELY(is_valid)) {
+ static_cast<BreakIf*>(inst_out)->SetNumNextIterValues(
+ inst_in.break_if().num_next_iter_values());
+ } else {
+ Error() << "invalid value for num_next_iter_values()";
+ }
}
return inst_out;
@@ -416,12 +525,8 @@
ir::If* CreateInstructionIf(const pb::InstructionIf& if_in) {
auto* if_out = mod_out_.allocators.instructions.Create<ir::If>();
- if (if_in.has_true_()) {
- if_out->SetTrue(Block(if_in.true_()));
- }
- if (if_in.has_false_()) {
- if_out->SetFalse(Block(if_in.false_()));
- }
+ if_out->SetTrue(if_in.has_true_() ? Block(if_in.true_()) : b.Block());
+ if_out->SetFalse(if_in.has_false_() ? Block(if_in.false_()) : b.Block());
return if_out;
}
@@ -440,16 +545,16 @@
ir::Loop* CreateInstructionLoop(const pb::InstructionLoop& loop_in) {
auto* loop_out = mod_out_.allocators.instructions.Create<ir::Loop>();
- if (loop_in.has_initalizer()) {
- loop_out->SetInitializer(Block(loop_in.initalizer()));
+ if (loop_in.has_initializer()) {
+ loop_out->SetInitializer(Block(loop_in.initializer()));
} else {
- loop_out->SetInitializer(mod_out_.blocks.Create());
+ loop_out->SetInitializer(b.Block());
}
loop_out->SetBody(BlockAs<ir::MultiInBlock>(loop_in.body()));
if (loop_in.has_continuing()) {
loop_out->SetContinuing(BlockAs<ir::MultiInBlock>(loop_in.continuing()));
} else {
- loop_out->SetContinuing(mod_out_.blocks.Create<ir::MultiInBlock>());
+ loop_out->SetContinuing(b.MultiInBlock());
}
return loop_out;
}
@@ -491,7 +596,7 @@
case_out.block->SetParent(switch_out);
for (auto selector_in : case_in.selectors()) {
ir::Switch::CaseSelector selector_out{};
- selector_out.val = b.Constant(ConstantValue(selector_in));
+ selector_out.val = Constant(selector_in);
case_out.selectors.Push(std::move(selector_out));
}
if (case_in.is_default()) {
@@ -562,7 +667,9 @@
case pb::Type::KindCase::KIND_NOT_SET:
break;
}
- TINT_ICE() << type_in.kind_case();
+
+ Error() << "invalid Type.kind: " << std::to_string(type_in.kind_case());
+ return mod_out_.Types().invalid();
}
const type::Type* CreateTypeBasic(pb::TypeBasic basic_in) {
@@ -584,15 +691,28 @@
case pb::TypeBasic::TypeBasic_INT_MAX_SENTINEL_DO_NOT_USE_:
break;
}
- TINT_ICE() << "invalid TypeBasic: " << basic_in;
+
+ Error() << "invalid TypeBasic: " << std::to_string(basic_in);
+ return mod_out_.Types().invalid();
}
- const type::Vector* CreateTypeVector(const pb::TypeVector& vector_in) {
+ const type::Type* CreateTypeVector(const pb::TypeVector& vector_in) {
+ const auto width = vector_in.width();
+ if (TINT_UNLIKELY(width < 2 || width > 4)) {
+ Error() << "invalid vector width";
+ return mod_out_.Types().invalid();
+ }
auto* el_ty = Type(vector_in.element_type());
return mod_out_.Types().vec(el_ty, vector_in.width());
}
- const type::Matrix* CreateTypeMatrix(const pb::TypeMatrix& matrix_in) {
+ const type::Type* CreateTypeMatrix(const pb::TypeMatrix& matrix_in) {
+ const auto rows = matrix_in.num_rows();
+ const auto cols = matrix_in.num_columns();
+ if (TINT_UNLIKELY(rows < 2 || rows > 4 || cols < 2 || cols > 4)) {
+ Error() << "invalid matrix dimensions";
+ return mod_out_.Types().invalid();
+ }
auto* el_ty = Type(matrix_in.element_type());
auto* column_ty = mod_out_.Types().vec(el_ty, matrix_in.num_rows());
return mod_out_.Types().mat(column_ty, matrix_in.num_columns());
@@ -605,15 +725,38 @@
return mod_out_.Types().ptr(address_space, store_ty, access);
}
- const type::Struct* CreateTypeStruct(const pb::TypeStruct& struct_in) {
+ const type::Type* CreateTypeStruct(const pb::TypeStruct& struct_in) {
+ auto struct_name = struct_in.name();
+ if (TINT_UNLIKELY(struct_name.empty())) {
+ Error() << "struct must have a name";
+ return mod_out_.Types().invalid();
+ }
+ if (!struct_names_.Add(struct_name)) {
+ Error() << "duplicate struct name: " << style::Type(struct_name);
+ return mod_out_.Types().invalid();
+ }
+
Vector<const core::type::StructMember*, 8> members_out;
uint32_t offset = 0;
for (auto& member_in : struct_in.member()) {
- auto symbol = mod_out_.symbols.Register(member_in.name());
+ auto member_name = member_in.name();
+ if (TINT_UNLIKELY(member_name.empty())) {
+ Error() << "struct member must have a name";
+ return mod_out_.Types().invalid();
+ }
+ auto symbol = mod_out_.symbols.Register(member_name);
auto* type = Type(member_in.type());
auto index = static_cast<uint32_t>(members_out.Length());
auto align = member_in.align();
auto size = member_in.size();
+ if (TINT_UNLIKELY(align == 0)) {
+ Error() << "struct member must have non-zero alignment";
+ align = 1;
+ }
+ if (TINT_UNLIKELY(size == 0)) {
+ Error() << "struct member must have non-zero size";
+ size = 1;
+ }
core::type::StructMemberAttributes attributes_out{};
if (member_in.has_attributes()) {
auto& attributes_in = member_in.attributes();
@@ -641,7 +784,11 @@
offset += size;
members_out.Push(member_out);
}
- auto name = mod_out_.symbols.Register(struct_in.name());
+ if (TINT_UNLIKELY(members_out.IsEmpty())) {
+ Error() << "struct requires at least one member";
+ return mod_out_.Types().invalid();
+ }
+ auto name = mod_out_.symbols.Register(struct_name);
return mod_out_.Types().Struct(name, std::move(members_out));
}
@@ -649,16 +796,29 @@
return mod_out_.Types().atomic(Type(atomic_in.type()));
}
- const type::Array* CreateTypeArray(const pb::TypeArray& array_in) {
+ const type::Type* CreateTypeArray(const pb::TypeArray& array_in) {
auto* element = Type(array_in.element());
- uint32_t stride = static_cast<uint32_t>(array_in.stride());
- uint32_t count = static_cast<uint32_t>(array_in.count());
+ uint32_t stride = array_in.stride();
+ uint32_t count = array_in.count();
+ if (element->Align() == 0 || element->Size() == 0) {
+ Error() << "cannot create an array of an unsized type";
+ return mod_out_.Types().invalid();
+ }
+ uint32_t implicit_stride = tint::RoundUp(element->Align(), element->Size());
+ if (stride < implicit_stride) {
+ Error() << "array element stride is smaller than the implicit stride";
+ return mod_out_.Types().invalid();
+ }
return count > 0 ? mod_out_.Types().array(element, count, stride)
: mod_out_.Types().runtime_array(element, stride);
}
- const type::DepthTexture* CreateTypeDepthTexture(const pb::TypeDepthTexture& texture_in) {
+ const type::Type* CreateTypeDepthTexture(const pb::TypeDepthTexture& texture_in) {
auto dimension = TextureDimension(texture_in.dimension());
+ if (!type::DepthTexture::IsValidDimension(dimension)) {
+ Error() << "invalid DepthTexture dimension";
+ return mod_out_.Types().invalid();
+ }
return mod_out_.Types().Get<type::DepthTexture>(dimension);
}
@@ -675,9 +835,13 @@
return mod_out_.Types().Get<type::MultisampledTexture>(dimension, sub_type);
}
- const type::DepthMultisampledTexture* CreateTypeDepthMultisampledTexture(
+ const type::Type* CreateTypeDepthMultisampledTexture(
const pb::TypeDepthMultisampledTexture& texture_in) {
auto dimension = TextureDimension(texture_in.dimension());
+ if (!type::DepthMultisampledTexture::IsValidDimension(dimension)) {
+ Error() << "invalid DepthMultisampledTexture dimension";
+ return mod_out_.Types().invalid();
+ }
return mod_out_.Types().Get<type::DepthMultisampledTexture>(dimension);
}
@@ -699,7 +863,13 @@
return mod_out_.Types().Get<type::Sampler>(kind);
}
- const type::Type* Type(size_t id) { return id > 0 ? types_[id - 1] : nullptr; }
+ const type::Type* Type(size_t id) {
+ if (TINT_UNLIKELY(id >= types_.Length())) {
+ Error() << "type id " << id << " out of range";
+ return mod_out_.Types().invalid();
+ }
+ return types_[id];
+ }
////////////////////////////////////////////////////////////////////////////
// Values
@@ -720,14 +890,15 @@
value_out = BlockParameter(value_in.block_parameter());
break;
case pb::Value::KindCase::kConstant:
- value_out = b.Constant(ConstantValue(value_in.constant()));
+ value_out = Constant(value_in.constant());
break;
case pb::Value::KindCase::KIND_NOT_SET:
break;
}
if (!value_out) {
- TINT_ICE() << "invalid TypeDecl.kind: " << value_in.kind_case();
+ Error() << "invalid value kind: " << std::to_string(value_in.kind_case());
+ return b.InvalidConstant();
}
return value_out;
@@ -736,7 +907,7 @@
ir::InstructionResult* InstructionResult(const pb::InstructionResult& res_in) {
auto* type = Type(res_in.type());
auto* res_out = b.InstructionResult(type);
- if (res_in.has_name()) {
+ if (!res_in.name().empty()) {
mod_out_.SetName(res_out, res_in.name());
}
return res_out;
@@ -745,7 +916,7 @@
ir::FunctionParam* FunctionParameter(const pb::FunctionParameter& param_in) {
auto* type = Type(param_in.type());
auto* param_out = b.FunctionParam(type);
- if (param_in.has_name()) {
+ if (!param_in.name().empty()) {
mod_out_.SetName(param_out, param_in.name());
}
@@ -772,22 +943,30 @@
ir::BlockParam* BlockParameter(const pb::BlockParameter& param_in) {
auto* type = Type(param_in.type());
auto* param_out = b.BlockParam(type);
- if (param_in.has_name()) {
+ if (!param_in.name().empty()) {
mod_out_.SetName(param_out, param_in.name());
}
return param_out;
}
- ir::Value* Value(uint32_t id) { return id > 0 ? values_[id - 1] : nullptr; }
+ ir::Constant* Constant(uint32_t value_id) { return b.Constant(ConstantValue(value_id)); }
+
+ ir::Value* Value(uint32_t id) {
+ if (TINT_UNLIKELY(id > values_.Length())) {
+ Error() << "value id " << id << " out of range";
+ return nullptr;
+ }
+ return id > 0 ? values_[id - 1] : nullptr;
+ }
template <typename T>
T* ValueAs(uint32_t id) {
auto* value = Value(id);
- if (auto cast = value->As<T>(); TINT_LIKELY(cast)) {
+ if (auto cast = As<T>(value); TINT_LIKELY(cast)) {
return cast;
}
- TINT_ICE() << "Value " << id << " is " << (value ? value->TypeInfo().name : "<null>")
- << " expected " << TypeInfo::Of<T>().name;
+ Error() << "value " << id << " is " << (value ? value->TypeInfo().name : "<null>")
+ << " expected " << TypeInfo::Of<T>().name;
return nullptr;
}
@@ -805,7 +984,8 @@
case pb::ConstantValue::KindCase::KIND_NOT_SET:
break;
}
- TINT_ICE() << "invalid ConstantValue.kind: " << value_in.kind_case();
+ Error() << "invalid ConstantValue.kind: " << std::to_string(value_in.kind_case());
+ return b.InvalidConstant()->Value();
}
const core::constant::Value* CreateConstantScalar(const pb::ConstantValueScalar& value_in) {
@@ -817,33 +997,69 @@
case pb::ConstantValueScalar::KindCase::kU32:
return b.ConstantValue(u32(value_in.u32()));
case pb::ConstantValueScalar::KindCase::kF32:
- return b.ConstantValue(f32(value_in.f32()));
+ return b.ConstantValue(CheckFinite(f32(value_in.f32())));
case pb::ConstantValueScalar::KindCase::kF16:
- return b.ConstantValue(f16(value_in.f16()));
+ return b.ConstantValue(CheckFinite(f16(value_in.f16())));
case pb::ConstantValueScalar::KindCase::KIND_NOT_SET:
break;
}
- TINT_ICE() << "invalid ConstantValueScalar.kind: " << value_in.kind_case();
+ Error() << "invalid ConstantValueScalar.kind: " << std::to_string(value_in.kind_case());
+ return b.InvalidConstant()->Value();
}
const core::constant::Value* CreateConstantComposite(
const pb::ConstantValueComposite& composite_in) {
auto* type = Type(composite_in.type());
+ auto type_elements = type->Elements();
+ size_t num_values = static_cast<size_t>(composite_in.elements().size());
+ if (TINT_UNLIKELY(type_elements.count == 0)) {
+ Error() << "cannot create a composite of type " << type->FriendlyName();
+ return b.InvalidConstant()->Value();
+ }
+ if (TINT_UNLIKELY(type_elements.count != num_values)) {
+ Error() << "constant composite type " << type->FriendlyName() << " expects "
+ << type_elements.count << " elements, but " << num_values << " values encoded";
+ return b.InvalidConstant()->Value();
+ }
Vector<const core::constant::Value*, 8> elements_out;
for (auto element_id : composite_in.elements()) {
- elements_out.Push(ConstantValue(element_id));
+ uint32_t i = static_cast<uint32_t>(elements_out.Length());
+ auto* value = ConstantValue(element_id);
+ if (auto* el_type = type->Element(i); TINT_UNLIKELY(value->Type() != el_type)) {
+ Error() << "constant composite element value type " << value->Type()->FriendlyName()
+ << " does not match element type " << el_type->FriendlyName();
+ return b.InvalidConstant()->Value();
+ }
+ elements_out.Push(value);
}
return mod_out_.constant_values.Composite(type, std::move(elements_out));
}
const core::constant::Value* CreateConstantSplat(const pb::ConstantValueSplat& splat_in) {
auto* type = Type(splat_in.type());
- auto* elem = ConstantValue(splat_in.elements());
- return mod_out_.constant_values.Splat(type, elem);
+ uint32_t num_elements = type->Elements().count;
+ if (TINT_UNLIKELY(num_elements == 0)) {
+ Error() << "cannot create a splat of type " << type->FriendlyName();
+ return b.InvalidConstant()->Value();
+ }
+ auto* value = ConstantValue(splat_in.elements());
+ for (uint32_t i = 0; i < num_elements; i++) {
+ auto* el_type = type->Element(i);
+ if (TINT_UNLIKELY(el_type != value->Type())) {
+ Error() << "constant splat element value type " << value->Type()->FriendlyName()
+ << " does not match element " << i << " type " << el_type->FriendlyName();
+ return b.InvalidConstant()->Value();
+ }
+ }
+ return mod_out_.constant_values.Splat(type, value);
}
const core::constant::Value* ConstantValue(uint32_t id) {
- return id > 0 ? constant_values_[id - 1] : nullptr;
+ if (TINT_UNLIKELY(id >= constant_values_.Length())) {
+ Error() << "constant value id " << id << " out of range";
+ return b.InvalidConstant()->Value();
+ }
+ return constant_values_[id];
}
////////////////////////////////////////////////////////////////////////////
diff --git a/src/tint/lang/core/ir/binary/encode.cc b/src/tint/lang/core/ir/binary/encode.cc
index d908788..05f06bd 100644
--- a/src/tint/lang/core/ir/binary/encode.cc
+++ b/src/tint/lang/core/ir/binary/encode.cc
@@ -110,7 +110,7 @@
}
Vector<pb::Function*, 8> fns_out;
for (auto& fn_in : mod_in_.functions) {
- uint32_t id = static_cast<uint32_t>(fns_out.Length() + 1);
+ uint32_t id = static_cast<uint32_t>(mod_out_.functions().size());
fns_out.Push(mod_out_.add_functions());
functions_.Add(fn_in, id);
}
@@ -153,7 +153,7 @@
fn_out->set_block(Block(fn_in->Block()));
}
- uint32_t Function(const ir::Function* fn_in) { return fn_in ? *functions_.Get(fn_in) : 0; }
+ uint32_t Function(const ir::Function* fn_in) { return *functions_.Get(fn_in); }
pb::PipelineStage PipelineStage(Function::PipelineStage stage) {
switch (stage) {
@@ -173,12 +173,11 @@
// Blocks
////////////////////////////////////////////////////////////////////////////
uint32_t Block(const ir::Block* block_in) {
- if (block_in == nullptr) {
- return 0;
- }
+ TINT_ASSERT(block_in != nullptr);
+
return blocks_.GetOrAdd(block_in, [&]() -> uint32_t {
+ auto id = static_cast<uint32_t>(mod_out_.blocks().size());
auto& block_out = *mod_out_.add_blocks();
- auto id = static_cast<uint32_t>(blocks_.Count());
for (auto* inst : *block_in) {
Instruction(*block_out.add_instructions(), inst);
}
@@ -296,7 +295,7 @@
void InstructionLoop(pb::InstructionLoop& loop_out, const ir::Loop* loop_in) {
if (loop_in->HasInitializer()) {
- loop_out.set_initalizer(Block(loop_in->Initializer()));
+ loop_out.set_initializer(Block(loop_in->Initializer()));
}
loop_out.set_body(Block(loop_in->Body()));
if (loop_in->HasContinuing()) {
@@ -352,9 +351,7 @@
// Types
////////////////////////////////////////////////////////////////////////////
uint32_t Type(const core::type::Type* type_in) {
- if (type_in == nullptr) {
- return 0;
- }
+ TINT_ASSERT(type_in != nullptr);
return types_.GetOrAdd(type_in, [&]() -> uint32_t {
pb::Type type_out;
tint::Switch(
@@ -393,7 +390,7 @@
TINT_ICE_ON_NO_MATCH);
mod_out_.mutable_types()->Add(std::move(type_out));
- return static_cast<uint32_t>(mod_out_.types().size());
+ return static_cast<uint32_t>(mod_out_.types().size() - 1);
});
}
@@ -564,9 +561,7 @@
// ConstantValues
////////////////////////////////////////////////////////////////////////////
uint32_t ConstantValue(const core::constant::Value* constant_in) {
- if (!constant_in) {
- return 0;
- }
+ TINT_ASSERT(constant_in != nullptr);
return constant_values_.GetOrAdd(constant_in, [&] {
pb::ConstantValue constant_out;
tint::Switch(
@@ -595,7 +590,7 @@
TINT_ICE_ON_NO_MATCH);
mod_out_.mutable_constant_values()->Add(std::move(constant_out));
- return static_cast<uint32_t>(mod_out_.constant_values().size());
+ return static_cast<uint32_t>(mod_out_.constant_values().size() - 1);
});
}
diff --git a/src/tint/lang/core/ir/binary/ir.proto b/src/tint/lang/core/ir/binary/ir.proto
index 31c2821..a5271c3 100644
--- a/src/tint/lang/core/ir/binary/ir.proto
+++ b/src/tint/lang/core/ir/binary/ir.proto
@@ -328,9 +328,9 @@
}
message InstructionLoop {
- optional uint32 initalizer = 1; // Module.blocks
- optional uint32 body = 2; // Module.blocks
- optional uint32 continuing = 3; // Module.blocks
+ optional uint32 initializer = 1; // Module.blocks
+ optional uint32 body = 2; // Module.blocks
+ optional uint32 continuing = 3; // Module.blocks
}
message InstructionExitIf {}
diff --git a/src/tint/lang/core/ir/disassembly.cc b/src/tint/lang/core/ir/disassembly.cc
index 1a8a9a6..6909d45 100644
--- a/src/tint/lang/core/ir/disassembly.cc
+++ b/src/tint/lang/core/ir/disassembly.cc
@@ -320,7 +320,7 @@
out_ << ", ";
}
SourceMarker sm(this);
- out_ << NameOf(p) << ":" << StyleType(p->Type()->FriendlyName());
+ out_ << NameOf(p) << ":" << StyleType(p->Type() ? p->Type()->FriendlyName() : "undef");
sm.Store(p);
EmitParamAttributes(p);
diff --git a/src/tint/lang/core/ir/function_param.cc b/src/tint/lang/core/ir/function_param.cc
index 69b1dc0..00d72e5 100644
--- a/src/tint/lang/core/ir/function_param.cc
+++ b/src/tint/lang/core/ir/function_param.cc
@@ -35,9 +35,7 @@
namespace tint::core::ir {
-FunctionParam::FunctionParam(const core::type::Type* ty) : type_(ty) {
- TINT_ASSERT(ty != nullptr);
-}
+FunctionParam::FunctionParam(const core::type::Type* ty) : type_(ty) {}
FunctionParam::~FunctionParam() = default;
diff --git a/src/tint/lang/core/ir/function_param.h b/src/tint/lang/core/ir/function_param.h
index c9a0cea..6ff2235 100644
--- a/src/tint/lang/core/ir/function_param.h
+++ b/src/tint/lang/core/ir/function_param.h
@@ -73,6 +73,10 @@
/// @returns the type of the var
const core::type::Type* Type() const override { return type_; }
+ /// Sets the type of the parameter to @p type
+ /// @param type the new type of the parameter
+ void SetType(const core::type::Type* type) { type_ = type; }
+
/// @copydoc Value::Clone()
FunctionParam* Clone(CloneContext& ctx) override;
diff --git a/src/tint/lang/core/ir/function_param_test.cc b/src/tint/lang/core/ir/function_param_test.cc
index 80b150f..30db305 100644
--- a/src/tint/lang/core/ir/function_param_test.cc
+++ b/src/tint/lang/core/ir/function_param_test.cc
@@ -36,16 +36,6 @@
using namespace tint::core::number_suffixes; // NOLINT
using IR_FunctionParamTest = IRTestHelper;
-TEST_F(IR_FunctionParamTest, Fail_NullType) {
- EXPECT_DEATH_IF_SUPPORTED(
- {
- Module mod;
- Builder b{mod};
- b.FunctionParam(nullptr);
- },
- "");
-}
-
TEST_F(IR_FunctionParamTest, Fail_SetDuplicateBuiltin) {
EXPECT_DEATH_IF_SUPPORTED(
{
diff --git a/src/tint/lang/core/ir/validator.cc b/src/tint/lang/core/ir/validator.cc
index 4a9a866..2912315 100644
--- a/src/tint/lang/core/ir/validator.cc
+++ b/src/tint/lang/core/ir/validator.cc
@@ -769,6 +769,11 @@
return;
}
+ if (!param->Type()) {
+ AddError(param) << "function parameter has nullptr type";
+ return;
+ }
+
// References not allowed on function signatures even with Capability::kAllowRefTypes
if (HoldsType<type::Reference>(param->Type())) {
AddError(param) << "references are not permitted as parameter types";
diff --git a/src/tint/lang/core/ir/validator_test.cc b/src/tint/lang/core/ir/validator_test.cc
index 820182a..9d5b0f1 100644
--- a/src/tint/lang/core/ir/validator_test.cc
+++ b/src/tint/lang/core/ir/validator_test.cc
@@ -239,6 +239,28 @@
)");
}
+TEST_F(IR_ValidatorTest, Function_ParameterWithNullType) {
+ auto* f = b.Function("my_func", ty.void_());
+ auto* p = b.FunctionParam("my_param", nullptr);
+ f->SetParams({p});
+ f->Block()->Append(b.Return(f));
+
+ auto res = ir::Validate(mod);
+ ASSERT_NE(res, Success);
+ EXPECT_EQ(res.Failure().reason.Str(),
+ R"(:1:17 error: function parameter has nullptr type
+%my_func = func(%my_param:undef):void {
+ ^^^^^^^^^^^^^^^
+
+note: # Disassembly
+%my_func = func(%my_param:undef):void {
+ $B1: {
+ ret
+ }
+}
+)");
+}
+
TEST_F(IR_ValidatorTest, CallToFunctionOutsideModule) {
auto* f = b.Function("f", ty.void_());
auto* g = b.Function("g", ty.void_());
diff --git a/src/tint/lang/hlsl/writer/ast_raise/decompose_memory_access.cc b/src/tint/lang/hlsl/writer/ast_raise/decompose_memory_access.cc
index 0b2500a..bacc9a0 100644
--- a/src/tint/lang/hlsl/writer/ast_raise/decompose_memory_access.cc
+++ b/src/tint/lang/hlsl/writer/ast_raise/decompose_memory_access.cc
@@ -321,10 +321,10 @@
/// BufferAccess describes a single storage or uniform buffer access
struct BufferAccess {
- sem::GlobalVariable const* var = nullptr; // Storage or uniform buffer variable
- Offset const* offset = nullptr; // The byte offset on var
- core::type::Type const* type = nullptr; // The type of the access
- operator bool() const { return var; } // Returns true if valid
+ sem::GlobalVariable const* var = nullptr; // Storage or uniform buffer variable
+ Offset const* offset = nullptr; // The byte offset on var
+ core::type::Type const* type = nullptr; // The type of the access
+ explicit operator bool() const { return var; } // Returns true if valid
};
/// Store describes a single storage or uniform buffer write
diff --git a/src/tint/lang/msl/writer/printer/function_test.cc b/src/tint/lang/msl/writer/printer/function_test.cc
index 0513a48..8ba60c2 100644
--- a/src/tint/lang/msl/writer/printer/function_test.cc
+++ b/src/tint/lang/msl/writer/printer/function_test.cc
@@ -25,6 +25,7 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#include "src/tint/lang/core/type/sampled_texture.h"
#include "src/tint/lang/msl/writer/printer/helper_test.h"
namespace tint::msl::writer {
@@ -41,7 +42,7 @@
)");
}
-TEST_F(MslPrinterTest, EntryPointParameterBindingPoint) {
+TEST_F(MslPrinterTest, EntryPointParameterBufferBindingPoint) {
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
auto* storage = b.FunctionParam("storage", ty.ptr(core::AddressSpace::kStorage, ty.i32()));
auto* uniform = b.FunctionParam("uniform", ty.ptr(core::AddressSpace::kUniform, ty.i32()));
@@ -57,5 +58,22 @@
)");
}
+TEST_F(MslPrinterTest, EntryPointParameterHandleBindingPoint) {
+ auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ auto* texture = b.FunctionParam("texture", t);
+ auto* sampler = b.FunctionParam("sampler", ty.sampler());
+ texture->SetBindingPoint(0, 1);
+ sampler->SetBindingPoint(0, 2);
+ func->SetParams({texture, sampler});
+ func->Block()->Append(b.Return(func));
+
+ ASSERT_TRUE(Generate()) << err_ << output_;
+ EXPECT_EQ(output_, MetalHeader() + R"(
+fragment void foo(texture2d<float, access::sample> texture [[texture(1)]], sampler sampler [[sampler(2)]]) {
+}
+)");
+}
+
} // namespace
} // namespace tint::msl::writer
diff --git a/src/tint/lang/msl/writer/printer/printer.cc b/src/tint/lang/msl/writer/printer/printer.cc
index 7c58cd4..0ae4046 100644
--- a/src/tint/lang/msl/writer/printer/printer.cc
+++ b/src/tint/lang/msl/writer/printer/printer.cc
@@ -329,16 +329,28 @@
}
if (auto binding_point = param->BindingPoint()) {
- auto ptr = param->Type()->As<core::type::Pointer>();
TINT_ASSERT(binding_point->group == 0);
- switch (ptr->AddressSpace()) {
- case core::AddressSpace::kStorage:
- case core::AddressSpace::kUniform:
- out << " [[buffer(" << binding_point->binding << ")]]";
- break;
- default:
- TINT_UNREACHABLE() << "invalid address space with binding point: "
- << ptr->AddressSpace();
+ if (auto ptr = param->Type()->As<core::type::Pointer>()) {
+ switch (ptr->AddressSpace()) {
+ case core::AddressSpace::kStorage:
+ case core::AddressSpace::kUniform:
+ out << " [[buffer(" << binding_point->binding << ")]]";
+ break;
+ default:
+ TINT_UNREACHABLE() << "invalid address space with binding point: "
+ << ptr->AddressSpace();
+ }
+ } else {
+ // Handle types are declared by value instead of by pointer.
+ Switch(
+ param->Type(),
+ [&](const core::type::Texture*) {
+ out << " [[texture(" << binding_point->binding << ")]]";
+ },
+ [&](const core::type::Sampler*) {
+ out << " [[sampler(" << binding_point->binding << ")]]";
+ },
+ TINT_ICE_ON_NO_MATCH);
}
}
}
@@ -1040,7 +1052,6 @@
switch (sc) {
case core::AddressSpace::kFunction:
case core::AddressSpace::kPrivate:
- case core::AddressSpace::kHandle:
out << "thread";
break;
case core::AddressSpace::kWorkgroup:
diff --git a/src/tint/lang/msl/writer/raise/module_scope_vars.cc b/src/tint/lang/msl/writer/raise/module_scope_vars.cc
index 0cb07c1..00a569f 100644
--- a/src/tint/lang/msl/writer/raise/module_scope_vars.cc
+++ b/src/tint/lang/msl/writer/raise/module_scope_vars.cc
@@ -88,14 +88,32 @@
ProcessFunction(*func);
}
- // Replace uses of each module-scope variable with pointers extracted from the structure.
+ // Replace uses of each module-scope variable with values extracted from the structure.
uint32_t index = 0;
for (auto& var : module_vars) {
- var->Result(0)->ReplaceAllUsesWith([&](core::ir::Usage use) { //
- return GetPointerFromStruct(var, use.instruction, index);
+ Vector<core::ir::Instruction*, 16> to_destroy;
+ auto* ptr = var->Result(0)->Type()->As<core::type::Pointer>();
+ var->Result(0)->ForEachUse([&](core::ir::Usage use) { //
+ auto* extracted_variable = GetVariableFromStruct(var, use.instruction, index);
+
+ // We drop the pointer from handle variables and store them in the struct by value
+ // instead, so remove any load instructions for the handle address space.
+ if (use.instruction->Is<core::ir::Load>() &&
+ ptr->AddressSpace() == core::AddressSpace::kHandle) {
+ use.instruction->Result(0)->ReplaceAllUsesWith(extracted_variable);
+ to_destroy.Push(use.instruction);
+ return;
+ }
+
+ use.instruction->SetOperand(use.operand_index, extracted_variable);
});
var->Destroy();
index++;
+
+ // Clean up instructions that need to be removed.
+ for (auto* inst : to_destroy) {
+ inst->Destroy();
+ }
}
}
@@ -106,6 +124,13 @@
for (auto* global : *ir.root_block) {
if (auto* var = global->As<core::ir::Var>()) {
auto* type = var->Result(0)->Type();
+
+ // Handle types drop the pointer and are passed around by value.
+ auto* ptr = type->As<core::type::Pointer>();
+ if (ptr->AddressSpace() == core::AddressSpace::kHandle) {
+ type = ptr->StoreType();
+ }
+
auto name = ir.NameOf(var);
if (!name) {
name = ir.symbols.New();
@@ -181,6 +206,14 @@
decl = param;
break;
}
+ case core::AddressSpace::kHandle: {
+ // Handle types become function parameters and drop the pointer.
+ auto* param = b.FunctionParam(ptr->UnwrapPtr());
+ param->SetBindingPoint(var->BindingPoint());
+ func->AppendParam(param);
+ decl = param;
+ break;
+ }
default:
TINT_UNREACHABLE() << "unhandled address space: " << ptr->AddressSpace();
}
@@ -218,18 +251,26 @@
return param;
}
- /// Get a pointer from the module-scope variable replacement structure, inserting new access
+ /// Get a variable from the module-scope variable replacement structure, inserting new access
/// instructions before @p inst.
/// @param var the variable to get the replacement for
/// @param inst the instruction that uses the variable
/// @param index the index of the variable in the structure member list
- /// @returns the pointer extracted from the structure
- core::ir::Value* GetPointerFromStruct(core::ir::Var* var,
- core::ir::Instruction* inst,
- uint32_t index) {
+ /// @returns the variable extracted from the structure
+ core::ir::Value* GetVariableFromStruct(core::ir::Var* var,
+ core::ir::Instruction* inst,
+ uint32_t index) {
auto* func = ContainingFunction(inst);
auto* struct_value = function_to_struct_value.GetOr(func, nullptr);
- auto* access = b.Access(var->Result(0)->Type(), struct_value, u32(index));
+ auto* type = var->Result(0)->Type();
+
+ // Handle types drop the pointer and are passed around by value.
+ auto* ptr = type->As<core::type::Pointer>();
+ if (ptr->AddressSpace() == core::AddressSpace::kHandle) {
+ type = ptr->StoreType();
+ }
+
+ auto* access = b.Access(type, struct_value, u32(index));
access->InsertBefore(inst);
return access->Result(0);
}
diff --git a/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc b/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc
index 17ffdc1..3bef228 100644
--- a/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc
+++ b/src/tint/lang/msl/writer/raise/module_scope_vars_test.cc
@@ -30,6 +30,7 @@
#include <utility>
#include "src/tint/lang/core/ir/transform/helper_test.h"
+#include "src/tint/lang/core/type/sampled_texture.h"
using namespace tint::core::fluent_types; // NOLINT
using namespace tint::core::number_suffixes; // NOLINT
@@ -306,6 +307,63 @@
EXPECT_EQ(expect, str());
}
+TEST_F(MslWriter_ModuleScopeVarsTest, HandleTypes) {
+ auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
+ auto* var_t = b.Var("t", ty.ptr<handle>(t));
+ auto* var_s = b.Var("s", ty.ptr<handle>(ty.sampler()));
+ var_t->SetBindingPoint(1, 2);
+ var_s->SetBindingPoint(3, 4);
+ mod.root_block->Append(var_t);
+ mod.root_block->Append(var_s);
+
+ auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ auto* load_t = b.Load(var_t);
+ auto* load_s = b.Load(var_s);
+ b.Call<vec4<f32>>(core::BuiltinFn::kTextureSample, load_t, load_s, b.Splat<vec2<f32>>(0_f));
+ b.Return(func);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %t:ptr<handle, texture_2d<f32>, read> = var @binding_point(1, 2)
+ %s:ptr<handle, sampler, read> = var @binding_point(3, 4)
+}
+
+%foo = @fragment func():void {
+ $B2: {
+ %4:texture_2d<f32> = load %t
+ %5:sampler = load %s
+ %6:vec4<f32> = textureSample %4, %5, vec2<f32>(0.0f)
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+tint_module_vars_struct = struct @align(1) {
+ t:texture_2d<f32> @offset(0)
+ s:sampler @offset(0)
+}
+
+%foo = @fragment func(%t:texture_2d<f32> [@binding_point(1, 2)], %s:sampler [@binding_point(3, 4)]):void {
+ $B1: {
+ %4:tint_module_vars_struct = construct %t, %s
+ %tint_module_vars:tint_module_vars_struct = let %4
+ %6:texture_2d<f32> = access %tint_module_vars, 0u
+ %7:sampler = access %tint_module_vars, 1u
+ %8:vec4<f32> = textureSample %6, %7, vec2<f32>(0.0f)
+ ret
+ }
+}
+)";
+
+ Run(ModuleScopeVars);
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(MslWriter_ModuleScopeVarsTest, MultipleAddressSpaces) {
auto* var_a = b.Var("a", ty.ptr<uniform, i32, core::Access::kRead>());
auto* var_b = b.Var("b", ty.ptr<storage, i32, core::Access::kReadWrite>());
@@ -607,6 +665,84 @@
EXPECT_EQ(expect, str());
}
+TEST_F(MslWriter_ModuleScopeVarsTest, CallFunctionThatUsesVars_HandleTypes) {
+ auto* t = ty.Get<core::type::SampledTexture>(core::type::TextureDimension::k2d, ty.f32());
+ auto* var_t = b.Var("t", ty.ptr<handle>(t));
+ auto* var_s = b.Var("s", ty.ptr<handle>(ty.sampler()));
+ var_t->SetBindingPoint(1, 2);
+ var_s->SetBindingPoint(3, 4);
+ mod.root_block->Append(var_t);
+ mod.root_block->Append(var_s);
+
+ auto* foo = b.Function("foo", ty.vec4<f32>());
+ auto* param = b.FunctionParam<i32>("param");
+ foo->SetParams({param});
+ b.Append(foo->Block(), [&] {
+ auto* load_t = b.Load(var_t);
+ auto* load_s = b.Load(var_s);
+ auto* result = b.Call<vec4<f32>>(core::BuiltinFn::kTextureSample, load_t, load_s,
+ b.Splat<vec2<f32>>(0_f));
+ b.Return(foo, result);
+ });
+
+ auto* func = b.Function("main", ty.void_(), core::ir::Function::PipelineStage::kFragment);
+ b.Append(func->Block(), [&] {
+ b.Call(foo, 42_i);
+ b.Return(func);
+ });
+
+ auto* src = R"(
+$B1: { # root
+ %t:ptr<handle, texture_2d<f32>, read> = var @binding_point(1, 2)
+ %s:ptr<handle, sampler, read> = var @binding_point(3, 4)
+}
+
+%foo = func(%param:i32):vec4<f32> {
+ $B2: {
+ %5:texture_2d<f32> = load %t
+ %6:sampler = load %s
+ %7:vec4<f32> = textureSample %5, %6, vec2<f32>(0.0f)
+ ret %7
+ }
+}
+%main = @fragment func():void {
+ $B3: {
+ %9:vec4<f32> = call %foo, 42i
+ ret
+ }
+}
+)";
+ EXPECT_EQ(src, str());
+
+ auto* expect = R"(
+tint_module_vars_struct = struct @align(1) {
+ t:texture_2d<f32> @offset(0)
+ s:sampler @offset(0)
+}
+
+%foo = func(%param:i32, %tint_module_vars:tint_module_vars_struct):vec4<f32> {
+ $B1: {
+ %4:texture_2d<f32> = access %tint_module_vars, 0u
+ %5:sampler = access %tint_module_vars, 1u
+ %6:vec4<f32> = textureSample %4, %5, vec2<f32>(0.0f)
+ ret %6
+ }
+}
+%main = @fragment func(%t:texture_2d<f32> [@binding_point(1, 2)], %s:sampler [@binding_point(3, 4)]):void {
+ $B2: {
+ %10:tint_module_vars_struct = construct %t, %s
+ %tint_module_vars_1:tint_module_vars_struct = let %10 # %tint_module_vars_1: 'tint_module_vars'
+ %12:vec4<f32> = call %foo, 42i, %tint_module_vars_1
+ ret
+ }
+}
+)";
+
+ Run(ModuleScopeVars);
+
+ EXPECT_EQ(expect, str());
+}
+
TEST_F(MslWriter_ModuleScopeVarsTest, CallFunctionThatUsesVars_OutOfOrder) {
auto* var_a = b.Var("a", ty.ptr<storage, i32, core::Access::kRead>());
auto* var_b = b.Var("b", ty.ptr<storage, i32, core::Access::kReadWrite>());
diff --git a/src/tint/lang/spirv/reader/ast_parser/ast_parser.h b/src/tint/lang/spirv/reader/ast_parser/ast_parser.h
index 8f755f7..106221a 100644
--- a/src/tint/lang/spirv/reader/ast_parser/ast_parser.h
+++ b/src/tint/lang/spirv/reader/ast_parser/ast_parser.h
@@ -103,7 +103,7 @@
TypedExpression& operator=(const TypedExpression&);
/// @returns true if both type and expr are not nullptr
- operator bool() const { return type && expr; }
+ explicit operator bool() const { return type && expr; }
/// The type
const Type* type = nullptr;
diff --git a/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc b/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc
index 0dbbc5b..72efe20 100644
--- a/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc
+++ b/src/tint/lang/wgsl/ast/transform/zero_init_workgroup_memory.cc
@@ -119,7 +119,7 @@
ArrayIndices array_indices;
/// @returns true if the expr is not null (null usually indicates a failure)
- operator bool() const { return expr != nullptr; }
+ explicit operator bool() const { return expr != nullptr; }
};
/// Statement holds information about a statement that will zero workgroup
diff --git a/src/tint/lang/wgsl/helpers/append_vector.cc b/src/tint/lang/wgsl/helpers/append_vector.cc
index dab1506..8ea2ebd 100644
--- a/src/tint/lang/wgsl/helpers/append_vector.cc
+++ b/src/tint/lang/wgsl/helpers/append_vector.cc
@@ -45,7 +45,7 @@
struct VectorConstructorInfo {
const sem::Call* call = nullptr;
const sem::ValueConstructor* ctor = nullptr;
- operator bool() const { return call != nullptr; }
+ explicit operator bool() const { return call != nullptr; }
};
VectorConstructorInfo AsVectorConstructor(const sem::ValueExpression* expr) {
if (auto* call = expr->As<sem::Call>()) {
diff --git a/src/tint/lang/wgsl/reader/parser/parser.cc b/src/tint/lang/wgsl/reader/parser/parser.cc
index 3eddee4..b85555e 100644
--- a/src/tint/lang/wgsl/reader/parser/parser.cc
+++ b/src/tint/lang/wgsl/reader/parser/parser.cc
@@ -45,6 +45,7 @@
#include "src/tint/lang/wgsl/ast/id_attribute.h"
#include "src/tint/lang/wgsl/ast/if_statement.h"
#include "src/tint/lang/wgsl/ast/increment_decrement_statement.h"
+#include "src/tint/lang/wgsl/ast/input_attachment_index_attribute.h"
#include "src/tint/lang/wgsl/ast/invariant_attribute.h"
#include "src/tint/lang/wgsl/ast/loop_statement.h"
#include "src/tint/lang/wgsl/ast/return_statement.h"
@@ -3100,6 +3101,8 @@
return create<ast::GroupAttribute>(t.source(), args[0]);
case core::Attribute::kId:
return create<ast::IdAttribute>(t.source(), args[0]);
+ case core::Attribute::kInputAttachmentIndex:
+ return create<ast::InputAttachmentIndexAttribute>(t.source(), args[0]);
case core::Attribute::kInterpolate:
return create<ast::InterpolateAttribute>(t.source(), args[0],
args.Length() == 2 ? args[1] : nullptr);
diff --git a/src/tint/lang/wgsl/reader/parser/variable_attribute_test.cc b/src/tint/lang/wgsl/reader/parser/variable_attribute_test.cc
index e7c48e4..c61f47d 100644
--- a/src/tint/lang/wgsl/reader/parser/variable_attribute_test.cc
+++ b/src/tint/lang/wgsl/reader/parser/variable_attribute_test.cc
@@ -647,5 +647,48 @@
EXPECT_EQ(p->error(), "1:7: expected expression for group");
}
+TEST_F(WGSLParserTest, Attribute_InputAttachmentIndex) {
+ auto p = parser("input_attachment_index(4)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr);
+ auto* var_attr = attr.value->As<ast::Attribute>();
+ ASSERT_FALSE(p->has_error());
+ ASSERT_NE(var_attr, nullptr);
+ ASSERT_TRUE(var_attr->Is<ast::InputAttachmentIndexAttribute>());
+
+ auto* group = var_attr->As<ast::InputAttachmentIndexAttribute>();
+ ASSERT_TRUE(group->expr->Is<ast::IntLiteralExpression>());
+ auto* expr = group->expr->As<ast::IntLiteralExpression>();
+ EXPECT_EQ(expr->value, 4);
+ EXPECT_EQ(expr->suffix, ast::IntLiteralExpression::Suffix::kNone);
+}
+
+TEST_F(WGSLParserTest, Attribute_InputAttachmentIndex_expression) {
+ auto p = parser("input_attachment_index(4 + 5)");
+ auto attr = p->attribute();
+ EXPECT_TRUE(attr.matched);
+ EXPECT_FALSE(attr.errored);
+ ASSERT_NE(attr.value, nullptr);
+ auto* var_attr = attr.value->As<ast::Attribute>();
+ ASSERT_FALSE(p->has_error());
+ ASSERT_NE(var_attr, nullptr);
+ ASSERT_TRUE(var_attr->Is<ast::InputAttachmentIndexAttribute>());
+
+ auto* group = var_attr->As<ast::InputAttachmentIndexAttribute>();
+ ASSERT_TRUE(group->expr->Is<ast::BinaryExpression>());
+ auto* expr = group->expr->As<ast::BinaryExpression>();
+
+ EXPECT_EQ(core::BinaryOp::kAdd, expr->op);
+ auto* v = expr->lhs->As<ast::IntLiteralExpression>();
+ ASSERT_NE(nullptr, v);
+ EXPECT_EQ(v->value, 4u);
+
+ v = expr->rhs->As<ast::IntLiteralExpression>();
+ ASSERT_NE(nullptr, v);
+ EXPECT_EQ(v->value, 5u);
+}
+
} // namespace
} // namespace tint::wgsl::reader
diff --git a/src/tint/lang/wgsl/resolver/BUILD.bazel b/src/tint/lang/wgsl/resolver/BUILD.bazel
index 9893abb..e5ea130 100644
--- a/src/tint/lang/wgsl/resolver/BUILD.bazel
+++ b/src/tint/lang/wgsl/resolver/BUILD.bazel
@@ -125,6 +125,7 @@
"host_shareable_validation_test.cc",
"increment_decrement_validation_test.cc",
"inferred_type_test.cc",
+ "input_attachments_extension_test.cc",
"is_host_shareable_test.cc",
"is_storeable_test.cc",
"language_features_test.cc",
diff --git a/src/tint/lang/wgsl/resolver/BUILD.cmake b/src/tint/lang/wgsl/resolver/BUILD.cmake
index 429b9c0..326af91 100644
--- a/src/tint/lang/wgsl/resolver/BUILD.cmake
+++ b/src/tint/lang/wgsl/resolver/BUILD.cmake
@@ -123,6 +123,7 @@
lang/wgsl/resolver/host_shareable_validation_test.cc
lang/wgsl/resolver/increment_decrement_validation_test.cc
lang/wgsl/resolver/inferred_type_test.cc
+ lang/wgsl/resolver/input_attachments_extension_test.cc
lang/wgsl/resolver/is_host_shareable_test.cc
lang/wgsl/resolver/is_storeable_test.cc
lang/wgsl/resolver/language_features_test.cc
diff --git a/src/tint/lang/wgsl/resolver/BUILD.gn b/src/tint/lang/wgsl/resolver/BUILD.gn
index 9456dbb..6fb89d0 100644
--- a/src/tint/lang/wgsl/resolver/BUILD.gn
+++ b/src/tint/lang/wgsl/resolver/BUILD.gn
@@ -125,6 +125,7 @@
"host_shareable_validation_test.cc",
"increment_decrement_validation_test.cc",
"inferred_type_test.cc",
+ "input_attachments_extension_test.cc",
"is_host_shareable_test.cc",
"is_storeable_test.cc",
"language_features_test.cc",
diff --git a/src/tint/lang/wgsl/resolver/attribute_validation_test.cc b/src/tint/lang/wgsl/resolver/attribute_validation_test.cc
index e94ebbe..809c8e3 100644
--- a/src/tint/lang/wgsl/resolver/attribute_validation_test.cc
+++ b/src/tint/lang/wgsl/resolver/attribute_validation_test.cc
@@ -65,6 +65,7 @@
kDiagnostic,
kGroup,
kId,
+ kInputAttachmentIndex,
kInterpolate,
kInvariant,
kLocation,
@@ -93,6 +94,8 @@
return o << "@group";
case AttributeKind::kId:
return o << "@id";
+ case AttributeKind::kInputAttachmentIndex:
+ return o << "@input_attachment_index";
case AttributeKind::kInterpolate:
return o << "@interpolate";
case AttributeKind::kInvariant:
@@ -166,6 +169,10 @@
"1:2 error: '@id' is not valid for " + thing,
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ "1:2 error: '@input_attachment_index' is not valid for " + thing,
+ },
+ TestParams{
{AttributeKind::kInterpolate},
"1:2 error: '@interpolate' is not valid for " + thing,
},
@@ -230,6 +237,8 @@
return builder.Group(source, 1_a);
case AttributeKind::kId:
return builder.Id(source, 0_a);
+ case AttributeKind::kInputAttachmentIndex:
+ return builder.InputAttachmentIndex(source, 2_a);
case AttributeKind::kBlendSrc:
return builder.BlendSrc(source, 0_a);
case AttributeKind::kInterpolate:
@@ -264,6 +273,9 @@
case AttributeKind::kBlendSrc:
Enable(wgsl::Extension::kChromiumInternalDualSourceBlending);
break;
+ case AttributeKind::kInputAttachmentIndex:
+ Enable(wgsl::Extension::kChromiumInternalInputAttachments);
+ break;
default:
break;
}
@@ -345,6 +357,10 @@
R"(1:2 error: '@id' is not valid for functions)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for functions)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(1:2 error: '@interpolate' is not valid for functions)",
},
@@ -393,81 +409,86 @@
CHECK();
}
-INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
- NonVoidFunctionAttributeTest,
- testing::Values(
- TestParams{
- {AttributeKind::kAlign},
- R"(1:2 error: '@align' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kBinding},
- R"(1:2 error: '@binding' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kBlendSrc},
- R"(1:2 error: '@blend_src' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kBuiltinPosition},
- R"(1:2 error: '@builtin' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kColor},
- R"(1:2 error: '@color' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kDiagnostic},
- Pass,
- },
- TestParams{
- {AttributeKind::kGroup},
- R"(1:2 error: '@group' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kId},
- R"(1:2 error: '@id' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kInterpolate},
- R"(1:2 error: '@interpolate' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kInvariant},
- R"(1:2 error: '@invariant' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kLocation},
- R"(1:2 error: '@location' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kMustUse},
- Pass,
- },
- TestParams{
- {AttributeKind::kOffset},
- R"(1:2 error: '@offset' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kSize},
- R"(1:2 error: '@size' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kStageCompute},
- R"(9:9 error: missing entry point IO attribute on return type)",
- },
- TestParams{
- {AttributeKind::kStageCompute, AttributeKind::kWorkgroupSize},
- R"(9:9 error: missing entry point IO attribute on return type)",
- },
- TestParams{
- {AttributeKind::kStride},
- R"(1:2 error: '@stride' is not valid for functions)",
- },
- TestParams{
- {AttributeKind::kWorkgroupSize},
- R"(1:2 error: '@workgroup_size' is only valid for compute stages)",
- }));
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ NonVoidFunctionAttributeTest,
+ testing::Values(
+ TestParams{
+ {AttributeKind::kAlign},
+ R"(1:2 error: '@align' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kBinding},
+ R"(1:2 error: '@binding' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kBlendSrc},
+ R"(1:2 error: '@blend_src' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kBuiltinPosition},
+ R"(1:2 error: '@builtin' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kColor},
+ R"(1:2 error: '@color' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kDiagnostic},
+ Pass,
+ },
+ TestParams{
+ {AttributeKind::kGroup},
+ R"(1:2 error: '@group' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kId},
+ R"(1:2 error: '@id' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kInterpolate},
+ R"(1:2 error: '@interpolate' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kInvariant},
+ R"(1:2 error: '@invariant' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kLocation},
+ R"(1:2 error: '@location' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kMustUse},
+ Pass,
+ },
+ TestParams{
+ {AttributeKind::kOffset},
+ R"(1:2 error: '@offset' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kSize},
+ R"(1:2 error: '@size' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kStageCompute},
+ R"(9:9 error: missing entry point IO attribute on return type)",
+ },
+ TestParams{
+ {AttributeKind::kStageCompute, AttributeKind::kWorkgroupSize},
+ R"(9:9 error: missing entry point IO attribute on return type)",
+ },
+ TestParams{
+ {AttributeKind::kStride},
+ R"(1:2 error: '@stride' is not valid for functions)",
+ },
+ TestParams{
+ {AttributeKind::kWorkgroupSize},
+ R"(1:2 error: '@workgroup_size' is only valid for compute stages)",
+ }));
} // namespace FunctionTests
namespace FunctionInputAndOutputTests {
@@ -520,6 +541,10 @@
R"(1:2 error: '@id' is not valid for function parameters)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for function parameters)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(1:2 error: '@interpolate' is not valid for non-entry point function parameters)",
},
@@ -605,6 +630,10 @@
R"(1:2 error: '@id' is not valid for non-entry point function return types)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for non-entry point function return types)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(1:2 error: '@interpolate' is not valid for non-entry point function return types)",
},
@@ -695,6 +724,10 @@
R"(1:2 error: '@id' is not valid for function parameters)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for function parameters)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(1:2 error: '@interpolate' cannot be used by compute shaders)",
},
@@ -784,6 +817,10 @@
R"(1:2 error: '@id' is not valid for function parameters)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for function parameters)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(9:9 error: missing entry point IO attribute on parameter)",
},
@@ -887,6 +924,10 @@
R"(1:2 error: '@id' is not valid for function parameters)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for function parameters)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(9:9 error: missing entry point IO attribute on parameter)",
},
@@ -992,6 +1033,10 @@
R"(1:2 error: '@id' is not valid for entry point return types)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for entry point return types)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(1:2 error: '@interpolate' cannot be used by compute shaders)",
},
@@ -1082,6 +1127,10 @@
R"(1:2 error: '@id' is not valid for entry point return types)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for entry point return types)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(9:9 error: missing entry point IO attribute on return type)",
},
@@ -1190,6 +1239,10 @@
R"(1:2 error: '@id' is not valid for entry point return types)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for entry point return types)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(1:2 error: '@interpolate' can only be used with '@location')",
},
@@ -1315,6 +1368,10 @@
R"(1:2 error: '@id' is not valid for 'struct' declarations)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for 'struct' declarations)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(1:2 error: '@interpolate' is not valid for 'struct' declarations)",
},
@@ -1399,6 +1456,10 @@
R"(1:2 error: '@id' is not valid for 'struct' members)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for 'struct' members)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(1:2 error: '@interpolate' can only be used with '@location')",
},
@@ -1635,82 +1696,87 @@
CHECK();
}
-INSTANTIATE_TEST_SUITE_P(ResolverAttributeValidationTest,
- ArrayAttributeTest,
- testing::Values(
- TestParams{
- {AttributeKind::kAlign},
- R"(1:2 error: '@align' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kBinding},
- R"(1:2 error: '@binding' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kBlendSrc},
- R"(1:2 error: '@blend_src' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kBuiltinPosition},
- R"(1:2 error: '@builtin' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kDiagnostic},
- R"(1:2 error: '@diagnostic' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kGroup},
- R"(1:2 error: '@group' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kId},
- R"(1:2 error: '@id' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kInterpolate},
- R"(1:2 error: '@interpolate' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kInvariant},
- R"(1:2 error: '@invariant' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kLocation},
- R"(1:2 error: '@location' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kMustUse},
- R"(1:2 error: '@must_use' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kOffset},
- R"(1:2 error: '@offset' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kSize},
- R"(1:2 error: '@size' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kStageCompute},
- R"(1:2 error: '@compute' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kStride},
- Pass,
- },
- TestParams{
- {AttributeKind::kWorkgroupSize},
- R"(1:2 error: '@workgroup_size' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kBinding, AttributeKind::kGroup},
- R"(1:2 error: '@binding' is not valid for 'array' types)",
- },
- TestParams{
- {AttributeKind::kStride, AttributeKind::kStride},
- R"(3:4 error: duplicate stride attribute
+INSTANTIATE_TEST_SUITE_P(
+ ResolverAttributeValidationTest,
+ ArrayAttributeTest,
+ testing::Values(
+ TestParams{
+ {AttributeKind::kAlign},
+ R"(1:2 error: '@align' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kBinding},
+ R"(1:2 error: '@binding' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kBlendSrc},
+ R"(1:2 error: '@blend_src' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kBuiltinPosition},
+ R"(1:2 error: '@builtin' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kDiagnostic},
+ R"(1:2 error: '@diagnostic' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kGroup},
+ R"(1:2 error: '@group' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kId},
+ R"(1:2 error: '@id' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kInterpolate},
+ R"(1:2 error: '@interpolate' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kInvariant},
+ R"(1:2 error: '@invariant' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kLocation},
+ R"(1:2 error: '@location' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kMustUse},
+ R"(1:2 error: '@must_use' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kOffset},
+ R"(1:2 error: '@offset' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kSize},
+ R"(1:2 error: '@size' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kStageCompute},
+ R"(1:2 error: '@compute' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kStride},
+ Pass,
+ },
+ TestParams{
+ {AttributeKind::kWorkgroupSize},
+ R"(1:2 error: '@workgroup_size' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kBinding, AttributeKind::kGroup},
+ R"(1:2 error: '@binding' is not valid for 'array' types)",
+ },
+ TestParams{
+ {AttributeKind::kStride, AttributeKind::kStride},
+ R"(3:4 error: duplicate stride attribute
1:2 note: first attribute declared here)",
- }));
+ }));
using VariableAttributeTest = TestWithParams;
TEST_P(VariableAttributeTest, IsValid) {
@@ -1862,6 +1928,10 @@
R"(1:2 error: '@id' is not valid for 'const' declaration)",
},
TestParams{
+ {AttributeKind::kInputAttachmentIndex},
+ R"(1:2 error: '@input_attachment_index' is not valid for 'const' declaration)",
+ },
+ TestParams{
{AttributeKind::kInterpolate},
R"(1:2 error: '@interpolate' is not valid for 'const' declaration)",
},
diff --git a/src/tint/lang/wgsl/resolver/dependency_graph.cc b/src/tint/lang/wgsl/resolver/dependency_graph.cc
index 4838cde..bffa92e 100644
--- a/src/tint/lang/wgsl/resolver/dependency_graph.cc
+++ b/src/tint/lang/wgsl/resolver/dependency_graph.cc
@@ -52,6 +52,7 @@
#include "src/tint/lang/wgsl/ast/identifier.h"
#include "src/tint/lang/wgsl/ast/if_statement.h"
#include "src/tint/lang/wgsl/ast/increment_decrement_statement.h"
+#include "src/tint/lang/wgsl/ast/input_attachment_index_attribute.h"
#include "src/tint/lang/wgsl/ast/internal_attribute.h"
#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
#include "src/tint/lang/wgsl/ast/invariant_attribute.h"
@@ -379,6 +380,7 @@
[&](const ast::ColorAttribute* color) { TraverseExpression(color->expr); },
[&](const ast::GroupAttribute* group) { TraverseExpression(group->expr); },
[&](const ast::IdAttribute* id) { TraverseExpression(id->expr); },
+ [&](const ast::InputAttachmentIndexAttribute* idx) { TraverseExpression(idx->expr); },
[&](const ast::BlendSrcAttribute* index) { TraverseExpression(index->expr); },
[&](const ast::InterpolateAttribute* interpolate) {
TraverseExpression(interpolate->type);
diff --git a/src/tint/lang/wgsl/resolver/dependency_graph_test.cc b/src/tint/lang/wgsl/resolver/dependency_graph_test.cc
index e435301..f64306a 100644
--- a/src/tint/lang/wgsl/resolver/dependency_graph_test.cc
+++ b/src/tint/lang/wgsl/resolver/dependency_graph_test.cc
@@ -1761,6 +1761,7 @@
GlobalVar(Sym(), ty.sampler(core::type::SamplerKind::kSampler));
GlobalVar(Sym(), ty.i32(), Vector{Binding(V), Group(V)});
+ GlobalVar(Sym(), ty.input_attachment(T), Vector{Binding(V), Group(V), InputAttachmentIndex(V)});
GlobalVar(Sym(), ty.i32(), Vector{Location(V)});
Override(Sym(), ty.i32(), Vector{Id(V)});
diff --git a/src/tint/lang/wgsl/resolver/input_attachments_extension_test.cc b/src/tint/lang/wgsl/resolver/input_attachments_extension_test.cc
new file mode 100644
index 0000000..2b0449d
--- /dev/null
+++ b/src/tint/lang/wgsl/resolver/input_attachments_extension_test.cc
@@ -0,0 +1,166 @@
+// Copyright 2024 The Dawn & Tint Authors
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include "src/tint/lang/wgsl/resolver/resolver.h"
+#include "src/tint/lang/wgsl/resolver/resolver_helper_test.h"
+
+namespace tint::resolver {
+namespace {
+
+using namespace tint::core::fluent_types; // NOLINT
+using namespace tint::core::number_suffixes; // NOLINT
+
+using InputAttachmenExtensionTest = ResolverTest;
+
+// Test that input_attachment cannot be used without extension.
+TEST_F(InputAttachmenExtensionTest, InputAttachmentWithoutExtension) {
+ // @group(0) @binding(0) @input_attachment_index(3)
+ // var input_tex : input_attachment<f32>;
+
+ GlobalVar("input_tex", ty.input_attachment(ty.Of<f32>()),
+ Vector{Binding(0_u), Group(0_u), InputAttachmentIndex(3_u)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: use of 'input_attachment' requires enabling extension 'chromium_internal_input_attachments')");
+}
+
+// Test that input_attachment cannot be declared locally.
+TEST_F(InputAttachmenExtensionTest, InputAttachmentLocalDecl) {
+ // enable chromium_internal_input_attachments;
+ // @fragment fn f() {
+ // var input_tex : input_attachment<f32>;
+ // }
+
+ Enable(Source{{12, 34}}, wgsl::Extension::kChromiumInternalInputAttachments);
+
+ Func("f", Empty, ty.void_(),
+ Vector{
+ Decl(Var("input_tex", ty.input_attachment(ty.Of<f32>()))),
+ },
+ Vector{Stage(ast::PipelineStage::kFragment)});
+ EXPECT_FALSE(r()->Resolve());
+}
+
+// Test that input_attachment cannot be declared without index.
+TEST_F(InputAttachmenExtensionTest, InputAttachmentWithoutIndex) {
+ // enable chromium_internal_input_attachments;
+ // @group(0) @binding(0)
+ // var input_tex : input_attachment<f32>;
+
+ Enable(Source{{12, 34}}, wgsl::Extension::kChromiumInternalInputAttachments);
+
+ GlobalVar("input_tex", ty.input_attachment(ty.Of<f32>()), Vector{Binding(0_u), Group(0_u)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ R"(error: 'input_attachment' variables require '@input_attachment_index' attribute)");
+}
+
+// Test that Resolver can get input_attachment_index value.
+TEST_F(InputAttachmenExtensionTest, InputAttachmentIndexValue) {
+ // enable chromium_internal_input_attachments;
+ // @group(0) @binding(0) @input_attachment_index(3)
+ // var input_tex : input_attachment<f32>;
+
+ Enable(Source{{12, 34}}, wgsl::Extension::kChromiumInternalInputAttachments);
+
+ auto* ast_var = GlobalVar("input_tex", ty.input_attachment(ty.Of<f32>()),
+ Vector{Binding(0_u), Group(0_u), InputAttachmentIndex(3_u)});
+
+ EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* sem_var = Sem().Get<sem::GlobalVariable>(ast_var);
+ ASSERT_NE(sem_var, nullptr);
+ EXPECT_EQ(sem_var->Attributes().input_attachment_index, 3u);
+}
+
+// Test that @input_attachment_index cannot be used without extension.
+TEST_F(InputAttachmenExtensionTest, InputAttachmentIndexWithoutExtension) {
+ // @group(0) @binding(0) @input_attachment_index(3)
+ // var input_tex : texture_2d<f32>;
+
+ GlobalVar("input_tex", ty.sampled_texture(core::type::TextureDimension::k2d, ty.Of<f32>()),
+ Vector{Binding(0_u), Group(0_u), InputAttachmentIndex(3_u)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: use of '@input_attachment_index' requires enabling extension 'chromium_internal_input_attachments')");
+}
+
+// Test that input_attachment_index's value cannot be float.
+TEST_F(InputAttachmenExtensionTest, InputAttachmentIndexInvalidValueType) {
+ // enable chromium_internal_input_attachments;
+ // @group(0) @binding(0) @input_attachment_index(3.0)
+ // var input_tex : input_attachment<f32>;
+
+ Enable(Source{{12, 34}}, wgsl::Extension::kChromiumInternalInputAttachments);
+
+ GlobalVar("input_tex", ty.input_attachment(ty.Of<f32>()),
+ Vector{Binding(0_u), Group(0_u), InputAttachmentIndex(3_f)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(error: '@input_attachment_index' must be an 'i32' or 'u32' value)");
+}
+
+// Test that input_attachment_index's value cannot be negative.
+TEST_F(InputAttachmenExtensionTest, InputAttachmentIndexNegative) {
+ // enable chromium_internal_input_attachments;
+ // @group(0) @binding(0) @input_attachment_index(-2)
+ // var input_tex : input_attachment<f32>;
+
+ Enable(Source{{12, 34}}, wgsl::Extension::kChromiumInternalInputAttachments);
+
+ GlobalVar("input_tex", ty.input_attachment(ty.Of<f32>()),
+ Vector{Binding(0_u), Group(0_u), InputAttachmentIndex(core::i32(-2))});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), R"(error: '@input_attachment_index' value must be non-negative)");
+}
+
+// Test that input_attachment_index cannot be used on non input_attachment variable.
+TEST_F(InputAttachmenExtensionTest, InputAttachmentIndexInvalidType) {
+ // enable chromium_internal_input_attachments;
+ // @group(0) @binding(0) @input_attachment_index(3)
+ // var input_tex : texture_2d<f32>;
+
+ Enable(Source{{12, 34}}, wgsl::Extension::kChromiumInternalInputAttachments);
+
+ GlobalVar("input_tex", ty.sampled_texture(core::type::TextureDimension::k2d, ty.Of<f32>()),
+ Vector{Binding(0_u), Group(0_u), InputAttachmentIndex(3_u)});
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(
+ r()->error(),
+ R"(error: cannot apply '@input_attachment_index' to declaration of type 'texture_2d<f32>'
+note: '@input_attachment_index' must only be applied to declarations of 'input_attachment' type)");
+}
+
+} // namespace
+} // namespace tint::resolver
diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc
index 0a17841..fe24ffc 100644
--- a/src/tint/lang/wgsl/resolver/resolver.cc
+++ b/src/tint/lang/wgsl/resolver/resolver.cc
@@ -65,6 +65,7 @@
#include "src/tint/lang/wgsl/ast/for_loop_statement.h"
#include "src/tint/lang/wgsl/ast/id_attribute.h"
#include "src/tint/lang/wgsl/ast/if_statement.h"
+#include "src/tint/lang/wgsl/ast/input_attachment_index_attribute.h"
#include "src/tint/lang/wgsl/ast/internal_attribute.h"
#include "src/tint/lang/wgsl/ast/interpolate_attribute.h"
#include "src/tint/lang/wgsl/ast/loop_statement.h"
@@ -614,7 +615,7 @@
bool has_io_address_space = sem->AddressSpace() == core::AddressSpace::kIn ||
sem->AddressSpace() == core::AddressSpace::kOut;
- std::optional<uint32_t> group, binding;
+ std::optional<uint32_t> group, binding, input_attachment_index;
for (auto* attribute : var->attributes) {
Mark(attribute);
enum Status { kSuccess, kErrored, kInvalid };
@@ -636,6 +637,14 @@
group = value.Get();
return kSuccess;
},
+ [&](const ast::InputAttachmentIndexAttribute* attr) {
+ auto value = InputAttachmentIndexAttribute(attr);
+ if (value != Success) {
+ return kErrored;
+ }
+ input_attachment_index = value.Get();
+ return kSuccess;
+ },
[&](const ast::LocationAttribute* attr) {
if (!has_io_address_space) {
return kInvalid;
@@ -708,6 +717,10 @@
global->Attributes().binding_point = BindingPoint{group.value(), binding.value()};
}
+ if (input_attachment_index) {
+ global->Attributes().input_attachment_index = input_attachment_index;
+ }
+
} else {
for (auto* attribute : var->attributes) {
Mark(attribute);
@@ -3876,6 +3889,31 @@
return static_cast<uint32_t>(value);
}
+tint::Result<uint32_t> Resolver::InputAttachmentIndexAttribute(
+ const ast::InputAttachmentIndexAttribute* attr) {
+ ExprEvalStageConstraint constraint{core::EvaluationStage::kConstant, "@input_attachment_index"};
+ TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint);
+
+ auto* materialized = Materialize(ValueExpression(attr->expr));
+ if (!materialized) {
+ return Failure{};
+ }
+ if (!materialized->Type()->IsAnyOf<core::type::I32, core::type::U32>()) {
+ AddError(attr->source) << style::Attribute("@input_attachment_index") << " must be an "
+ << style::Type("i32") << " or " << style::Type("u32") << " value";
+ return Failure{};
+ }
+
+ auto const_value = materialized->ConstantValue();
+ auto value = const_value->ValueAs<AInt>();
+ if (value < 0) {
+ AddError(attr->source) << style::Attribute("@input_attachment_index")
+ << " value must be non-negative";
+ return Failure{};
+ }
+ return static_cast<uint32_t>(value);
+}
+
tint::Result<sem::WorkgroupSize> Resolver::WorkgroupAttribute(const ast::WorkgroupAttribute* attr) {
// Set work-group size defaults.
sem::WorkgroupSize ws;
diff --git a/src/tint/lang/wgsl/resolver/resolver.h b/src/tint/lang/wgsl/resolver/resolver.h
index df18f91..1781493 100644
--- a/src/tint/lang/wgsl/resolver/resolver.h
+++ b/src/tint/lang/wgsl/resolver/resolver.h
@@ -442,6 +442,11 @@
/// @returns the group value on success.
tint::Result<uint32_t> GroupAttribute(const ast::GroupAttribute* attr);
+ /// Resolves the `@input_attachment_index` attribute @p attr
+ /// @returns the index value on success.
+ tint::Result<uint32_t> InputAttachmentIndexAttribute(
+ const ast::InputAttachmentIndexAttribute* attr);
+
/// Resolves the `@workgroup_size` attribute @p attr
/// @returns the workgroup size on success.
tint::Result<sem::WorkgroupSize> WorkgroupAttribute(const ast::WorkgroupAttribute* attr);
diff --git a/src/tint/lang/wgsl/resolver/uniformity.cc b/src/tint/lang/wgsl/resolver/uniformity.cc
index d6a5307..2faeb21 100644
--- a/src/tint/lang/wgsl/resolver/uniformity.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity.cc
@@ -1188,14 +1188,21 @@
const ast::IdentifierExpression* ident,
bool load_rule = false) {
// Helper to check if the entry point attribute of `obj` indicates non-uniformity.
- auto has_nonuniform_entry_point_attribute = [&](auto* obj) {
- // Only the num_workgroups and workgroup_id builtins are uniform.
+ auto has_nonuniform_entry_point_attribute = [&](auto* obj, auto* entry_point) {
+ // Only the num_workgroups and workgroup_id builtins, and subgroup_size builtin used in
+ // compute stage are uniform.
if (auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(obj->attributes)) {
auto builtin = b.Sem().Get(builtin_attr)->Value();
if (builtin == core::BuiltinValue::kNumWorkgroups ||
builtin == core::BuiltinValue::kWorkgroupId) {
return false;
}
+ if (builtin == core::BuiltinValue::kSubgroupSize) {
+ // Currently Tint only allow using subgroup_size builtin as a compute shader
+ // input.
+ TINT_ASSERT(entry_point->PipelineStage() == ast::PipelineStage::kCompute);
+ return false;
+ }
}
return true;
};
@@ -1216,14 +1223,16 @@
// is non-uniform.
bool uniform = true;
for (auto* member : str->Members()) {
- if (has_nonuniform_entry_point_attribute(member->Declaration())) {
+ if (has_nonuniform_entry_point_attribute(member->Declaration(),
+ user_func->Declaration())) {
uniform = false;
}
}
node->AddEdge(uniform ? cf : current_function_->may_be_non_uniform);
return std::make_pair(cf, node);
} else {
- if (has_nonuniform_entry_point_attribute(param->Declaration())) {
+ if (has_nonuniform_entry_point_attribute(param->Declaration(),
+ user_func->Declaration())) {
node->AddEdge(current_function_->may_be_non_uniform);
} else {
node->AddEdge(cf);
diff --git a/src/tint/lang/wgsl/resolver/uniformity_test.cc b/src/tint/lang/wgsl/resolver/uniformity_test.cc
index 9c15fa6..092cfd0 100644
--- a/src/tint/lang/wgsl/resolver/uniformity_test.cc
+++ b/src/tint/lang/wgsl/resolver/uniformity_test.cc
@@ -514,7 +514,11 @@
class ComputeBuiltin : public UniformityAnalysisTestBase,
public ::testing::TestWithParam<BuiltinEntry> {};
TEST_P(ComputeBuiltin, AsParam) {
- std::string src = R"(
+ std::string src = std::string((GetParam().name == "subgroup_size")
+ ? R"(enable chromium_experimental_subgroups;
+)"
+ : "") +
+ R"(
@compute @workgroup_size(64)
fn main(@builtin()" + GetParam().name +
R"() b : )" + GetParam().type + R"() {
@@ -545,7 +549,11 @@
}
TEST_P(ComputeBuiltin, InStruct) {
- std::string src = R"(
+ std::string src = std::string((GetParam().name == "subgroup_size")
+ ? R"(enable chromium_experimental_subgroups;
+)"
+ : "") +
+ R"(
struct S {
@builtin()" + GetParam().name +
R"() b : )" + GetParam().type + R"(
@@ -585,7 +593,8 @@
BuiltinEntry{"local_invocation_index", "u32", false},
BuiltinEntry{"global_invocation_id", "vec3<u32>", false},
BuiltinEntry{"workgroup_id", "vec3<u32>", true},
- BuiltinEntry{"num_workgroups", "vec3<u32>", true}),
+ BuiltinEntry{"num_workgroups", "vec3<u32>", true},
+ BuiltinEntry{"subgroup_size", "u32", true}),
[](const ::testing::TestParamInfo<ComputeBuiltin::ParamType>& p) {
return p.param.name;
});
diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc
index f482497..559aa50 100644
--- a/src/tint/lang/wgsl/resolver/validator.cc
+++ b/src/tint/lang/wgsl/resolver/validator.cc
@@ -442,6 +442,12 @@
}
bool Validator::InputAttachment(const core::type::InputAttachment* t, const Source& source) const {
+ if (!enabled_extensions_.Contains(wgsl::Extension::kChromiumInternalInputAttachments)) {
+ AddError(source) << "use of " << style::Type("input_attachment")
+ << " requires enabling extension "
+ << style::Code("chromium_internal_input_attachments");
+ return false;
+ }
if (!t->type()->UnwrapRef()->IsAnyOf<core::type::F32, core::type::I32, core::type::U32>()) {
AddError(source) << "input_attachment<type>: type must be f32, i32 or u32";
return false;
@@ -450,6 +456,29 @@
return true;
}
+bool Validator::InputAttachmentIndexAttribute(const ast::InputAttachmentIndexAttribute* attr,
+ const core::type::Type* type,
+ const Source& source) const {
+ if (!enabled_extensions_.Contains(wgsl::Extension::kChromiumInternalInputAttachments)) {
+ AddError(source) << "use of " << style::Attribute("@input_attachment_index")
+ << " requires enabling extension "
+ << style::Code("chromium_internal_input_attachments");
+ return false;
+ }
+
+ if (!type->Is<core::type::InputAttachment>()) {
+ std::string invalid_type = sem_.TypeNameOf(type);
+ AddError(source) << "cannot apply " << style::Attribute("@input_attachment_index")
+ << " to declaration of type " << style::Type(invalid_type);
+ AddNote(attr->source) << style::Attribute("@input_attachment_index")
+ << " must only be applied to declarations of "
+ << style::Type("input_attachment") << " type";
+ return false;
+ }
+
+ return true;
+}
+
bool Validator::Materialize(const core::type::Type* to,
const core::type::Type* from,
const Source& source) const {
@@ -708,6 +737,13 @@
return false;
}
+ auto* input_attachment_index_attr =
+ ast::GetAttribute<ast::InputAttachmentIndexAttribute>(decl->attributes);
+ if (input_attachment_index_attr &&
+ !InputAttachmentIndexAttribute(input_attachment_index_attr, global->Type()->UnwrapRef(),
+ decl->source)) {
+ return false;
+ }
switch (global->AddressSpace()) {
case core::AddressSpace::kUniform:
case core::AddressSpace::kStorage:
@@ -720,6 +756,13 @@
<< style::Attribute("@binding") << " attributes";
return false;
}
+ if (global->Type()->UnwrapRef()->Is<core::type::InputAttachment>() &&
+ !input_attachment_index_attr) {
+ AddError(decl->source)
+ << style::Type("input_attachment") << " variables require "
+ << style::Attribute("@input_attachment_index") << " attribute";
+ return false;
+ }
break;
}
default: {
diff --git a/src/tint/lang/wgsl/resolver/validator.h b/src/tint/lang/wgsl/resolver/validator.h
index 6cdbfaf..fb54048 100644
--- a/src/tint/lang/wgsl/resolver/validator.h
+++ b/src/tint/lang/wgsl/resolver/validator.h
@@ -434,6 +434,15 @@
/// @returns true on success, false otherwise
bool InputAttachment(const core::type::InputAttachment* t, const Source& source) const;
+ /// Validates a input attachment index attribute
+ /// @param attr the input attachment index attribute to validate
+ /// @param type the variable type
+ /// @param source the source of declaration using the attribute
+ /// @returns true on success, false otherwise.
+ bool InputAttachmentIndexAttribute(const ast::InputAttachmentIndexAttribute* attr,
+ const core::type::Type* type,
+ const Source& source) const;
+
/// Validates a structure
/// @param str the structure to validate
/// @param stage the current pipeline stage
diff --git a/src/tint/lang/wgsl/sem/variable.h b/src/tint/lang/wgsl/sem/variable.h
index bf97929..51cac2a 100644
--- a/src/tint/lang/wgsl/sem/variable.h
+++ b/src/tint/lang/wgsl/sem/variable.h
@@ -169,6 +169,8 @@
/// @note a GlobalVariable generally doesn't have a `color` in WGSL, as it isn't allowed by
/// the spec. The location maybe attached by transforms such as CanonicalizeEntryPointIO.
std::optional<uint32_t> color;
+ /// The `input_attachment_index` attribute value for the variable, if set
+ std::optional<uint32_t> input_attachment_index;
};
/// GlobalVariable is a module-scope variable
diff --git a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
index b404397..7243f25 100644
--- a/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
+++ b/src/tint/lang/wgsl/writer/ir_to_program/ir_to_program.cc
@@ -1069,9 +1069,7 @@
/// name.
void Bind(const core::ir::Value* value, Symbol name) {
TINT_ASSERT(value);
-
- bool added = bindings_.Add(value, VariableValue{name});
- if (TINT_UNLIKELY(!added)) {
+ if (TINT_UNLIKELY(!bindings_.Add(value, VariableValue{name}))) {
TINT_ICE() << "Bind(" << value->TypeInfo().name << ") called twice for same value";
}
}
diff --git a/src/tint/utils/command/command_posix.cc b/src/tint/utils/command/command_posix.cc
index f3c30de..fae21be 100644
--- a/src/tint/utils/command/command_posix.cc
+++ b/src/tint/utils/command/command_posix.cc
@@ -75,7 +75,7 @@
operator int() { return handle_; }
/// @returns true if the file is not closed
- operator bool() { return handle_ != kClosed; }
+ explicit operator bool() { return handle_ != kClosed; }
private:
File(const File&) = delete;
@@ -103,7 +103,7 @@
}
/// @returns true if the pipe has an open read or write file
- operator bool() { return read || write; }
+ explicit operator bool() { return read || write; }
/// The reader end of the pipe
File read;
diff --git a/src/tint/utils/command/command_windows.cc b/src/tint/utils/command/command_windows.cc
index a6a577b..e591416 100644
--- a/src/tint/utils/command/command_windows.cc
+++ b/src/tint/utils/command/command_windows.cc
@@ -73,7 +73,7 @@
operator HANDLE() { return handle_; }
/// @returns true if the handle is not invalid
- operator bool() { return handle_ != nullptr; }
+ explicit operator bool() { return handle_ != nullptr; }
private:
Handle(const Handle&) = delete;
@@ -106,7 +106,7 @@
}
/// @returns true if the pipe has an open read or write file
- operator bool() { return read || write; }
+ explicit operator bool() { return read || write; }
/// The reader end of the pipe
Handle read;
diff --git a/src/tint/utils/containers/hashmap.h b/src/tint/utils/containers/hashmap.h
index 69790f3..df1ed08 100644
--- a/src/tint/utils/containers/hashmap.h
+++ b/src/tint/utils/containers/hashmap.h
@@ -90,7 +90,7 @@
T* value = nullptr;
/// @returns `true` if #value is not null.
- operator bool() const { return value; }
+ explicit operator bool() const { return value; }
/// @returns the dereferenced value, which must not be null.
T& operator*() const { return *value; }
@@ -129,7 +129,7 @@
bool added = false;
/// @returns #added
- operator bool() const { return added; }
+ explicit operator bool() const { return added; }
};
/// An unordered hashmap, with a fixed-size capacity that avoids heap allocations.
diff --git a/src/tint/utils/containers/hashmap_test.cc b/src/tint/utils/containers/hashmap_test.cc
index af926a2..f3b91bd 100644
--- a/src/tint/utils/containers/hashmap_test.cc
+++ b/src/tint/utils/containers/hashmap_test.cc
@@ -421,7 +421,7 @@
switch (rnd() % 7) {
case 0: { // Add
auto expected = reference.emplace(key, value).second;
- EXPECT_EQ(map.Add(key, value), expected) << "i:" << i;
+ EXPECT_EQ(map.Add(key, value).added, expected) << "i:" << i;
EXPECT_EQ(map.Get(key), value) << "i:" << i;
EXPECT_TRUE(map.Contains(key)) << "i:" << i;
break;
diff --git a/src/tint/utils/file/tmpfile.h b/src/tint/utils/file/tmpfile.h
index 0ae8e95..6bcd663 100644
--- a/src/tint/utils/file/tmpfile.h
+++ b/src/tint/utils/file/tmpfile.h
@@ -50,7 +50,7 @@
~TmpFile();
/// @return true if the temporary file was successfully created.
- operator bool() { return !path_.empty(); }
+ explicit operator bool() { return !path_.empty(); }
/// @return the path to the temporary file
std::string Path() const { return path_; }
diff --git a/src/tint/utils/id/generation_id.h b/src/tint/utils/id/generation_id.h
index 018b4a2..b3078df 100644
--- a/src/tint/utils/id/generation_id.h
+++ b/src/tint/utils/id/generation_id.h
@@ -67,7 +67,7 @@
uint32_t Value() const { return val; }
/// @returns true if this GenerationID is valid
- operator bool() const { return val != 0; }
+ explicit operator bool() const { return val != 0; }
private:
explicit GenerationID(uint32_t);
diff --git a/src/tint/utils/symbol/symbol.h b/src/tint/utils/symbol/symbol.h
index 9f26bb0..10b7bd7 100644
--- a/src/tint/utils/symbol/symbol.h
+++ b/src/tint/utils/symbol/symbol.h
@@ -83,7 +83,7 @@
bool IsValid() const { return val_ != static_cast<uint32_t>(-1); }
/// @returns true if the symbol is valid
- operator bool() const { return IsValid(); }
+ explicit operator bool() const { return IsValid(); }
/// @returns the value for the symbol
uint32_t value() const { return val_; }