writer/spirv: Clean up constant generation
Requiring a temporary stack-allocated ast::Literal is an unpleasant requirement to generate a SPIR-V constant value.
GenerateU32Literal() was also creating an invalid AST - the type was U32, yet an an ast::SintLiteral was used.
Instead add Constant for holding a constant value, and use this as the map key.
This also removes the last remaining use of ast::NullLiteral, which will be removed in the next change.
Change-Id: Ia85732784075f153503dbef101ba95018eaa4bf5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45342
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index 787afe2..fc82118 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -558,6 +558,7 @@
"writer/spirv/instruction.h",
"writer/spirv/operand.cc",
"writer/spirv/operand.h",
+ "writer/spirv/scalar_constant.h",
]
configs += [ ":tint_common_config" ]
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 9614ed4..9e5f29e 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -343,6 +343,7 @@
writer/spirv/instruction.h
writer/spirv/operand.cc
writer/spirv/operand.h
+ writer/spirv/scalar_constant.h
)
endif()
@@ -678,6 +679,7 @@
writer/spirv/builder_unary_op_expression_test.cc
writer/spirv/instruction_test.cc
writer/spirv/operand_test.cc
+ writer/spirv/scalar_constant_test.cc
writer/spirv/spv_dump.cc
writer/spirv/spv_dump.h
writer/spirv/test_helper.h
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 71054c8..7209aba 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -21,7 +21,6 @@
#include "src/ast/call_statement.h"
#include "src/ast/constant_id_decoration.h"
#include "src/ast/fallthrough_statement.h"
-#include "src/ast/null_literal.h"
#include "src/semantic/array.h"
#include "src/semantic/call.h"
#include "src/semantic/function.h"
@@ -365,12 +364,6 @@
return true;
}
-uint32_t Builder::GenerateU32Literal(uint32_t val) {
- type::U32 u32;
- ast::SintLiteral lit(Source{}, &u32, val);
- return GenerateLiteralIfNeeded(nullptr, &lit);
-}
-
bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) {
auto lhs_id = GenerateExpression(assign->lhs());
if (lhs_id == 0) {
@@ -648,8 +641,7 @@
// TODO(dsinclair) We could detect if the constructor is fully const and emit
// an initializer value for the variable instead of doing the OpLoad.
- ast::NullLiteral nl(Source{}, var->type()->UnwrapPtrIfNeeded());
- auto null_id = GenerateLiteralIfNeeded(var, &nl);
+ auto null_id = GenerateConstantNullIfNeeded(var->type()->UnwrapPtrIfNeeded());
if (null_id == 0) {
return 0;
}
@@ -779,8 +771,7 @@
} else if (sem->StorageClass() == ast::StorageClass::kPrivate ||
sem->StorageClass() == ast::StorageClass::kNone ||
sem->StorageClass() == ast::StorageClass::kOutput) {
- ast::NullLiteral nl(Source{}, type);
- init_id = GenerateLiteralIfNeeded(var, &nl);
+ init_id = GenerateConstantNullIfNeeded(type);
if (init_id == 0) {
return 0;
}
@@ -888,7 +879,7 @@
}
}
- auto idx_id = GenerateU32Literal(i);
+ auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
if (idx_id == 0) {
return 0;
}
@@ -913,7 +904,7 @@
}
if (info->source_type->Is<type::Pointer>()) {
- auto idx_id = GenerateU32Literal(val);
+ auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(val));
if (idx_id == 0) {
return 0;
}
@@ -1044,8 +1035,7 @@
auto ary_result = result_op();
- ast::NullLiteral nl(Source{}, ary_res_type);
- auto init = GenerateLiteralIfNeeded(nullptr, &nl);
+ auto init = GenerateConstantNullIfNeeded(ary_res_type);
// If we're access chaining into an array then we must be in a function
push_function_var(
@@ -1259,8 +1249,7 @@
// Generate the zero initializer if there are no values provided.
if (values.empty()) {
- ast::NullLiteral nl(Source{}, init->type()->UnwrapPtrIfNeeded());
- return GenerateLiteralIfNeeded(nullptr, &nl);
+ return GenerateConstantNullIfNeeded(init->type()->UnwrapPtrIfNeeded());
}
std::ostringstream out;
@@ -1370,7 +1359,7 @@
result_is_constant_composite = false;
} else {
// A global initializer, must use OpSpecConstantOp. Case 1.
- auto idx_id = GenerateU32Literal(i);
+ auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
if (idx_id == 0) {
return 0;
}
@@ -1392,8 +1381,8 @@
}
auto str = out.str();
- auto val = const_to_id_.find(str);
- if (val != const_to_id_.end()) {
+ auto val = type_constructor_to_id_.find(str);
+ if (val != type_constructor_to_id_.end()) {
return val->second;
}
@@ -1401,7 +1390,7 @@
ops.insert(ops.begin(), result);
ops.insert(ops.begin(), Operand::Int(type_id));
- const_to_id_[str] = result.to_i();
+ type_constructor_to_id_[str] = result.to_i();
if (result_is_spec_composite) {
push_type(spv::Op::OpSpecConstantComposite, ops);
@@ -1480,59 +1469,133 @@
uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var,
ast::Literal* lit) {
- auto type_id = GenerateTypeIfNeeded(lit->type());
- if (type_id == 0) {
- return 0;
- }
+ ScalarConstant constant;
- auto name = lit->name();
- bool is_spec_constant = false;
if (var && var->HasConstantIdDecoration()) {
- name = "__spec" + name;
- is_spec_constant = true;
- }
-
- auto val = const_to_id_.find(name);
- if (val != const_to_id_.end()) {
- return val->second;
- }
-
- auto result = result_op();
- auto result_id = result.to_i();
-
- if (is_spec_constant) {
- push_annot(spv::Op::OpDecorate,
- {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
- Operand::Int(var->constant_id())});
+ constant.is_spec_op = true;
+ constant.constant_id = var->constant_id();
}
if (auto* l = lit->As<ast::BoolLiteral>()) {
- if (l->IsTrue()) {
- push_type(is_spec_constant ? spv::Op::OpSpecConstantTrue
- : spv::Op::OpConstantTrue,
- {Operand::Int(type_id), result});
- } else {
- push_type(is_spec_constant ? spv::Op::OpSpecConstantFalse
- : spv::Op::OpConstantFalse,
- {Operand::Int(type_id), result});
- }
+ constant.kind = ScalarConstant::Kind::kBool;
+ constant.value.b = l->IsTrue();
} else if (auto* sl = lit->As<ast::SintLiteral>()) {
- push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
- {Operand::Int(type_id), result, Operand::Int(sl->value())});
+ constant.kind = ScalarConstant::Kind::kI32;
+ constant.value.i32 = sl->value();
} else if (auto* ul = lit->As<ast::UintLiteral>()) {
- push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
- {Operand::Int(type_id), result, Operand::Int(ul->value())});
+ constant.kind = ScalarConstant::Kind::kU32;
+ constant.value.u32 = ul->value();
} else if (auto* fl = lit->As<ast::FloatLiteral>()) {
- push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
- {Operand::Int(type_id), result, Operand::Float(fl->value())});
- } else if (lit->Is<ast::NullLiteral>()) {
- push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
+ constant.kind = ScalarConstant::Kind::kF32;
+ constant.value.f32 = fl->value();
} else {
error_ = "unknown literal type";
return 0;
}
- const_to_id_[name] = result_id;
+ return GenerateConstantIfNeeded(constant);
+}
+
+uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
+ auto it = const_to_id_.find(constant);
+ if (it != const_to_id_.end()) {
+ return it->second;
+ }
+
+ uint32_t type_id = 0;
+
+ switch (constant.kind) {
+ case ScalarConstant::Kind::kU32: {
+ type::U32 u32;
+ type_id = GenerateTypeIfNeeded(&u32);
+ break;
+ }
+ case ScalarConstant::Kind::kI32: {
+ type::I32 i32;
+ type_id = GenerateTypeIfNeeded(&i32);
+ break;
+ }
+ case ScalarConstant::Kind::kF32: {
+ type::F32 f32;
+ type_id = GenerateTypeIfNeeded(&f32);
+ break;
+ }
+ case ScalarConstant::Kind::kBool: {
+ type::Bool bool_;
+ type_id = GenerateTypeIfNeeded(&bool_);
+ break;
+ }
+ }
+
+ if (type_id == 0) {
+ return 0;
+ }
+
+ auto result = result_op();
+ auto result_id = result.to_i();
+
+ if (constant.is_spec_op) {
+ push_annot(spv::Op::OpDecorate,
+ {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
+ Operand::Int(constant.constant_id)});
+ }
+
+ switch (constant.kind) {
+ case ScalarConstant::Kind::kU32: {
+ push_type(
+ constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+ {Operand::Int(type_id), result, Operand::Int(constant.value.u32)});
+ break;
+ }
+ case ScalarConstant::Kind::kI32: {
+ push_type(
+ constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+ {Operand::Int(type_id), result, Operand::Int(constant.value.i32)});
+ break;
+ }
+ case ScalarConstant::Kind::kF32: {
+ push_type(
+ constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+ {Operand::Int(type_id), result, Operand::Float(constant.value.f32)});
+ break;
+ }
+ case ScalarConstant::Kind::kBool: {
+ if (constant.value.b) {
+ push_type(constant.is_spec_op ? spv::Op::OpSpecConstantTrue
+ : spv::Op::OpConstantTrue,
+ {Operand::Int(type_id), result});
+ } else {
+ push_type(constant.is_spec_op ? spv::Op::OpSpecConstantFalse
+ : spv::Op::OpConstantFalse,
+ {Operand::Int(type_id), result});
+ }
+ break;
+ }
+ }
+
+ const_to_id_[constant] = result_id;
+ return result_id;
+}
+
+uint32_t Builder::GenerateConstantNullIfNeeded(type::Type* type) {
+ auto type_id = GenerateTypeIfNeeded(type);
+ if (type_id == 0) {
+ return 0;
+ }
+
+ auto name = type->type_name();
+
+ auto it = const_null_to_id_.find(name);
+ if (it != const_null_to_id_.end()) {
+ return it->second;
+ }
+
+ auto result = result_op();
+ auto result_id = result.to_i();
+
+ push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
+
+ const_null_to_id_[name] = result_id;
return result_id;
}
@@ -2955,7 +3018,7 @@
if (ary->IsRuntimeArray()) {
push_type(spv::Op::OpTypeRuntimeArray, {result, Operand::Int(elem_type)});
} else {
- auto len_id = GenerateU32Literal(ary->size());
+ auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(ary->size()));
if (len_id == 0) {
return false;
}
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 4f3251d..6147c9f 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -37,6 +37,7 @@
#include "src/type/access_control_type.h"
#include "src/type/storage_texture_type.h"
#include "src/writer/spirv/function.h"
+#include "src/writer/spirv/scalar_constant.h"
namespace tint {
@@ -208,10 +209,6 @@
/// @param id the id to use for the label
/// @returns true on success.
bool GenerateLabel(uint32_t id);
- /// Generates a uint32_t literal.
- /// @param val the value to generate
- /// @returns the ID of the generated literal
- uint32_t GenerateU32Literal(uint32_t val);
/// Generates an assignment statement
/// @param assign the statement to generate
/// @returns true if the statement was successfully generated
@@ -486,6 +483,16 @@
return builder_.TypeOf(expr);
}
+ /// Generates a constant if needed
+ /// @param constant the constant to generate.
+ /// @returns the ID on success or 0 on failure
+ uint32_t GenerateConstantIfNeeded(const ScalarConstant& constant);
+
+ /// Generates a constant-null of the given type, if needed
+ /// @param type the type of the constant null to generate.
+ /// @returns the ID on success or 0 on failure
+ uint32_t GenerateConstantNullIfNeeded(type::Type* type);
+
ProgramBuilder builder_;
std::string error_;
uint32_t next_id_ = 1;
@@ -504,7 +511,9 @@
std::unordered_map<std::string, uint32_t> import_name_to_id_;
std::unordered_map<Symbol, uint32_t> func_symbol_to_id_;
std::unordered_map<std::string, uint32_t> type_name_to_id_;
- std::unordered_map<std::string, uint32_t> const_to_id_;
+ std::unordered_map<ScalarConstant, uint32_t> const_to_id_;
+ std::unordered_map<std::string, uint32_t> type_constructor_to_id_;
+ std::unordered_map<std::string, uint32_t> const_null_to_id_;
std::unordered_map<std::string, uint32_t>
texture_type_name_to_sampled_image_type_id_;
ScopeStack<uint32_t> scope_stack_;
diff --git a/src/writer/spirv/scalar_constant.h b/src/writer/spirv/scalar_constant.h
new file mode 100644
index 0000000..18dba91
--- /dev/null
+++ b/src/writer/spirv/scalar_constant.h
@@ -0,0 +1,115 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_
+#define SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_
+
+#include <stdint.h>
+
+#include <cstring>
+#include <functional>
+
+namespace tint {
+
+// Forward declarations
+namespace semantic {
+class Call;
+} // namespace semantic
+
+namespace writer {
+namespace spirv {
+
+/// ScalarConstant represents a scalar constant value
+struct ScalarConstant {
+ /// The constant value
+ union Value {
+ /// The value as a bool
+ bool b;
+ /// The value as a uint32_t
+ uint32_t u32;
+ /// The value as a int32_t
+ int32_t i32;
+ /// The value as a float
+ float f32;
+
+ /// The value that is wide enough to encompass all other types (including
+ /// future 64-bit data types).
+ uint64_t u64;
+ };
+
+ /// The kind of constant
+ enum class Kind { kBool, kU32, kI32, kF32 };
+
+ /// Constructor
+ inline ScalarConstant() { value.u64 = 0; }
+
+ /// @param value the value of the constant
+ /// @returns a new ScalarConstant with the provided value and kind Kind::kU32
+ static inline ScalarConstant U32(uint32_t value) {
+ ScalarConstant c;
+ c.value.u32 = value;
+ c.kind = Kind::kU32;
+ return c;
+ }
+
+ /// Equality operator
+ /// @param rhs the ScalarConstant to compare against
+ /// @returns true if this ScalarConstant is equal to `rhs`
+ inline bool operator==(const ScalarConstant& rhs) const {
+ return value.u64 == rhs.value.u64 && kind == rhs.kind &&
+ is_spec_op == rhs.is_spec_op && constant_id == rhs.constant_id;
+ }
+
+ /// Inequality operator
+ /// @param rhs the ScalarConstant to compare against
+ /// @returns true if this ScalarConstant is not equal to `rhs`
+ inline bool operator!=(const ScalarConstant& rhs) const {
+ return !(*this == rhs);
+ }
+
+ /// The constant value
+ Value value;
+ /// The constant value kind
+ Kind kind = Kind::kBool;
+ /// True if the constant is a specialization op
+ bool is_spec_op = false;
+ /// The identifier if a specialization op
+ uint32_t constant_id = 0;
+};
+
+} // namespace spirv
+} // namespace writer
+} // namespace tint
+
+namespace std {
+
+/// Custom std::hash specialization for tint::Symbol so symbols can be used as
+/// keys for std::unordered_map and std::unordered_set.
+template <>
+class hash<tint::writer::spirv::ScalarConstant> {
+ public:
+ /// @param c the ScalarConstant
+ /// @return the Symbol internal value
+ inline std::size_t operator()(
+ const tint::writer::spirv::ScalarConstant& c) const {
+ uint32_t value = 0;
+ std::memcpy(&value, &c.value, sizeof(value));
+ return (static_cast<std::size_t>(value) << 2) |
+ (static_cast<std::size_t>(c.kind) & 3);
+ }
+};
+
+} // namespace std
+
+#endif // SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_
diff --git a/src/writer/spirv/scalar_constant_test.cc b/src/writer/spirv/scalar_constant_test.cc
new file mode 100644
index 0000000..b514146
--- /dev/null
+++ b/src/writer/spirv/scalar_constant_test.cc
@@ -0,0 +1,60 @@
+// Copyright 2020 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/writer/spirv/scalar_constant.h"
+#include "src/writer/spirv/test_helper.h"
+
+namespace tint {
+namespace writer {
+namespace spirv {
+namespace {
+
+using SpirvScalarConstantTest = TestHelper;
+
+TEST_F(SpirvScalarConstantTest, Equality) {
+ ScalarConstant a{};
+ ScalarConstant b{};
+ EXPECT_EQ(a, b);
+
+ a.kind = ScalarConstant::Kind::kU32;
+ EXPECT_NE(a, b);
+ b.kind = ScalarConstant::Kind::kU32;
+ EXPECT_EQ(a, b);
+
+ a.value.b = true;
+ EXPECT_NE(a, b);
+ b.value.b = true;
+ EXPECT_EQ(a, b);
+
+ a.is_spec_op = true;
+ EXPECT_NE(a, b);
+ b.is_spec_op = true;
+ EXPECT_EQ(a, b);
+
+ a.constant_id = 3;
+ EXPECT_NE(a, b);
+ b.constant_id = 3;
+ EXPECT_EQ(a, b);
+}
+
+TEST_F(SpirvScalarConstantTest, U32) {
+ auto c = ScalarConstant::U32(123);
+ EXPECT_EQ(c.value.u32, 123u);
+ EXPECT_EQ(c.kind, ScalarConstant::Kind::kU32);
+}
+
+} // namespace
+} // namespace spirv
+} // namespace writer
+} // namespace tint
diff --git a/test/BUILD.gn b/test/BUILD.gn
index b49cf59..b69dd73 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -336,6 +336,7 @@
"../src/writer/spirv/builder_unary_op_expression_test.cc",
"../src/writer/spirv/instruction_test.cc",
"../src/writer/spirv/operand_test.cc",
+ "../src/writer/spirv/scalar_constant_test.cc",
"../src/writer/spirv/spv_dump.cc",
"../src/writer/spirv/spv_dump.h",
"../src/writer/spirv/test_helper.h",