[spirv-writer] Emit specialization constants.
This CL adds OpSpec constants to the SPIRV-Writer.
Bug: tint:151
Change-Id: I309013ca0b4cb514edd92fab3dab2e4faa15969a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/29101
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 6b4866d..c843386 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -355,7 +355,7 @@
uint32_t Builder::GenerateU32Literal(uint32_t val) {
ast::type::U32Type u32;
ast::SintLiteral lit(&u32, val);
- return GenerateLiteralIfNeeded(&lit);
+ return GenerateLiteralIfNeeded(nullptr, &lit);
}
bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) {
@@ -465,7 +465,7 @@
return GenerateCallExpression(expr->AsCall());
}
if (expr->IsConstructor()) {
- return GenerateConstructorExpression(expr->AsConstructor(), false);
+ return GenerateConstructorExpression(nullptr, expr->AsConstructor(), false);
}
if (expr->IsIdentifier()) {
return GenerateIdentifierExpression(expr->AsIdentifier());
@@ -611,7 +611,7 @@
// TODO(dsinclair) We could detect if the constructor is fully const and emit
// an initializer value for the variable instead of doing the OpLoad.
ast::NullLiteral nl(var->type()->UnwrapPtrIfNeeded());
- auto null_id = GenerateLiteralIfNeeded(&nl);
+ auto null_id = GenerateLiteralIfNeeded(var, &nl);
if (null_id == 0) {
return 0;
}
@@ -642,8 +642,8 @@
return false;
}
- init_id = GenerateConstructorExpression(var->constructor()->AsConstructor(),
- true);
+ init_id = GenerateConstructorExpression(
+ var, var->constructor()->AsConstructor(), true);
if (init_id == 0) {
return false;
}
@@ -689,7 +689,7 @@
var->storage_class() == ast::StorageClass::kNone ||
var->storage_class() == ast::StorageClass::kOutput) {
ast::NullLiteral nl(var->type()->UnwrapPtrIfNeeded());
- init_id = GenerateLiteralIfNeeded(&nl);
+ init_id = GenerateLiteralIfNeeded(var, &nl);
if (init_id == 0) {
return 0;
}
@@ -718,6 +718,8 @@
spv::Op::OpDecorate,
{Operand::Int(var_id), Operand::Int(SpvDecorationDescriptorSet),
Operand::Int(deco->AsSet()->value())});
+ } else if (deco->IsConstantId()) {
+ // Spec constants are handled elsewhere
} else {
error_ = "unknown decoration";
return false;
@@ -1033,10 +1035,11 @@
}
uint32_t Builder::GenerateConstructorExpression(
+ ast::Variable* var,
ast::ConstructorExpression* expr,
bool is_global_init) {
if (expr->IsScalarConstructor()) {
- return GenerateLiteralIfNeeded(expr->AsScalarConstructor()->literal());
+ return GenerateLiteralIfNeeded(var, expr->AsScalarConstructor()->literal());
}
if (expr->IsTypeConstructor()) {
return GenerateTypeConstructorExpression(expr->AsTypeConstructor(),
@@ -1055,7 +1058,7 @@
// Generate the zero initializer if there are no values provided.
if (values.empty()) {
ast::NullLiteral nl(init->type()->UnwrapPtrIfNeeded());
- return GenerateLiteralIfNeeded(&nl);
+ return GenerateLiteralIfNeeded(nullptr, &nl);
}
std::ostringstream out;
@@ -1102,7 +1105,8 @@
for (const auto& e : values) {
uint32_t id = 0;
if (constructor_is_const) {
- id = GenerateConstructorExpression(e->AsConstructor(), is_global_init);
+ id = GenerateConstructorExpression(nullptr, e->AsConstructor(),
+ is_global_init);
} else {
id = GenerateExpression(e.get());
id = GenerateLoadIfNeeded(e->result_type(), id);
@@ -1268,12 +1272,21 @@
return result_id;
}
-uint32_t Builder::GenerateLiteralIfNeeded(ast::Literal* lit) {
+uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var,
+ ast::Literal* lit) {
auto type_id = GenerateTypeIfNeeded(lit->type());
if (type_id == 0) {
return 0;
}
+
auto name = lit->name();
+ bool is_spec_constant = false;
+ if (var && var->IsDecorated() &&
+ var->AsDecorated()->HasConstantIdDecoration()) {
+ name = "__spec" + name;
+ is_spec_constant = true;
+ }
+
auto val = const_to_id_.find(name);
if (val != const_to_id_.end()) {
return val->second;
@@ -1282,21 +1295,34 @@
auto result = result_op();
auto result_id = result.to_i();
+ if (is_spec_constant) {
+ push_annot(spv::Op::OpDecorate,
+ {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
+ Operand::Int(var->AsDecorated()->constant_id())});
+ }
+
if (lit->IsBool()) {
if (lit->AsBool()->IsTrue()) {
- push_type(spv::Op::OpConstantTrue, {Operand::Int(type_id), result});
+ push_type(is_spec_constant ? spv::Op::OpSpecConstantTrue
+ : spv::Op::OpConstantTrue,
+ {Operand::Int(type_id), result});
} else {
- push_type(spv::Op::OpConstantFalse, {Operand::Int(type_id), result});
+ push_type(is_spec_constant ? spv::Op::OpSpecConstantFalse
+ : spv::Op::OpConstantFalse,
+ {Operand::Int(type_id), result});
}
} else if (lit->IsSint()) {
- push_type(spv::Op::OpConstant, {Operand::Int(type_id), result,
- Operand::Int(lit->AsSint()->value())});
+ push_type(
+ is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+ {Operand::Int(type_id), result, Operand::Int(lit->AsSint()->value())});
} else if (lit->IsUint()) {
- push_type(spv::Op::OpConstant, {Operand::Int(type_id), result,
- Operand::Int(lit->AsUint()->value())});
+ push_type(
+ is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+ {Operand::Int(type_id), result, Operand::Int(lit->AsUint()->value())});
} else if (lit->IsFloat()) {
- push_type(spv::Op::OpConstant, {Operand::Int(type_id), result,
- Operand::Float(lit->AsFloat()->value())});
+ push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+ {Operand::Int(type_id), result,
+ Operand::Float(lit->AsFloat()->value())});
} else if (lit->IsNull()) {
push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
} else {
@@ -1730,7 +1756,8 @@
spirv_params.push_back(Operand::Int(SpvImageOperandsLodMask));
ast::type::F32Type f32;
ast::FloatLiteral float_0(&f32, 0.0);
- spirv_params.push_back(Operand::Int(GenerateLiteralIfNeeded(&float_0)));
+ spirv_params.push_back(
+ Operand::Int(GenerateLiteralIfNeeded(nullptr, &float_0)));
}
if (op == spv::Op::OpNop) {
error_ = "unable to determine operator for: " + ident->name();
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index ccfce39..b6e3f7c 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -249,10 +249,12 @@
/// Generates an import instruction
void GenerateGLSLstd450Import();
/// Generates a constructor expression
+ /// @param var the variable generated for, nullptr if no variable associated.
/// @param expr the expression to generate
/// @param is_global_init set true if this is a global variable constructor
/// @returns the ID of the expression or 0 on failure.
- uint32_t GenerateConstructorExpression(ast::ConstructorExpression* expr,
+ uint32_t GenerateConstructorExpression(ast::Variable* var,
+ ast::ConstructorExpression* expr,
bool is_global_init);
/// Generates a type constructor expression
/// @param init the expression to generate
@@ -262,9 +264,10 @@
ast::TypeConstructorExpression* init,
bool is_global_init);
/// Generates a literal constant if needed
+ /// @param var the variable generated for, nullptr if no variable associated.
/// @param lit the literal to generate
/// @returns the ID on success or 0 on failure
- uint32_t GenerateLiteralIfNeeded(ast::Literal* lit);
+ uint32_t GenerateLiteralIfNeeded(ast::Variable* var, ast::Literal* lit);
/// Generates a binary expression
/// @param expr the expression to generate
/// @returns the expression ID on success or 0 otherwise
diff --git a/src/writer/spirv/builder_constructor_expression_test.cc b/src/writer/spirv/builder_constructor_expression_test.cc
index 73f07ad..12ac999 100644
--- a/src/writer/spirv/builder_constructor_expression_test.cc
+++ b/src/writer/spirv/builder_constructor_expression_test.cc
@@ -56,7 +56,7 @@
ast::Module mod;
Builder b(&mod);
- EXPECT_EQ(b.GenerateConstructorExpression(&c, true), 2u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &c, true), 2u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
@@ -84,7 +84,7 @@
EXPECT_TRUE(td.DetermineResultType(&t)) << td.error();
Builder b(&mod);
- EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &t, true), 5u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
@@ -190,7 +190,7 @@
ast::Module mod;
Builder b(&mod);
- EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 0u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &t, true), 0u);
EXPECT_TRUE(b.has_error());
EXPECT_EQ(b.error(), R"(constructor must be a constant expression)");
}
@@ -765,7 +765,7 @@
Builder b(&mod);
b.push_function(Function{});
- EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 3
@@ -807,7 +807,7 @@
Builder b(&mod);
b.push_function(Function{});
- EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 3
@@ -851,7 +851,7 @@
Builder b(&mod);
b.push_function(Function{});
- EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 4
@@ -895,7 +895,7 @@
Builder b(&mod);
b.push_function(Function{});
- EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 4
@@ -939,7 +939,7 @@
Builder b(&mod);
b.push_function(Function{});
- EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 11u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 11u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 4
@@ -987,7 +987,7 @@
Builder b(&mod);
b.push_function(Function{});
- EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 13u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 4
@@ -1033,7 +1033,7 @@
Builder b(&mod);
b.push_function(Function{});
- EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 13u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 4
@@ -1079,7 +1079,7 @@
Builder b(&mod);
b.push_function(Function{});
- EXPECT_EQ(b.GenerateConstructorExpression(&cast, true), 13u);
+ EXPECT_EQ(b.GenerateConstructorExpression(nullptr, &cast, true), 13u);
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 4
diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc
index 1e60c0f..73ae303 100644
--- a/src/writer/spirv/builder_global_variable_test.cc
+++ b/src/writer/spirv/builder_global_variable_test.cc
@@ -16,14 +16,18 @@
#include "gtest/gtest.h"
#include "src/ast/binding_decoration.h"
+#include "src/ast/bool_literal.h"
#include "src/ast/builtin.h"
#include "src/ast/builtin_decoration.h"
+#include "src/ast/constant_id_decoration.h"
#include "src/ast/decorated_variable.h"
#include "src/ast/float_literal.h"
#include "src/ast/location_decoration.h"
+#include "src/ast/module.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/set_decoration.h"
#include "src/ast/storage_class.h"
+#include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h"
@@ -359,6 +363,58 @@
)");
}
+TEST_F(BuilderTest, GlobalVar_ConstantId_Bool) {
+ ast::type::BoolType bool_type;
+
+ ast::VariableDecorationList decos;
+ decos.push_back(std::make_unique<ast::ConstantIdDecoration>(1200));
+
+ ast::DecoratedVariable v(std::make_unique<ast::Variable>(
+ "var", ast::StorageClass::kNone, &bool_type));
+ v.set_decorations(std::move(decos));
+ v.set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::BoolLiteral>(&bool_type, true)));
+
+ ast::Module mod;
+ Builder b(&mod);
+ EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+ EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "var"
+)");
+ EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 1200
+)");
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
+%2 = OpSpecConstantTrue %1
+%4 = OpTypePointer Private %1
+%3 = OpVariable %4 Private %2
+)");
+}
+
+TEST_F(BuilderTest, GlobalVar_ConstantId_Scalar) {
+ ast::type::F32Type f32;
+
+ ast::VariableDecorationList decos;
+ decos.push_back(std::make_unique<ast::ConstantIdDecoration>(0));
+
+ ast::DecoratedVariable v(
+ std::make_unique<ast::Variable>("var", ast::StorageClass::kNone, &f32));
+ v.set_decorations(std::move(decos));
+ v.set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 2.0)));
+
+ ast::Module mod;
+ Builder b(&mod);
+ EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error();
+ EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "var"
+)");
+ EXPECT_EQ(DumpInstructions(b.annots()), R"(OpDecorate %2 SpecId 0
+)");
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
+%2 = OpSpecConstant %1 2
+%4 = OpTypePointer Private %1
+%3 = OpVariable %4 Private %2
+)");
+}
+
struct BuiltinData {
ast::Builtin builtin;
SpvBuiltIn result;
diff --git a/src/writer/spirv/builder_literal_test.cc b/src/writer/spirv/builder_literal_test.cc
index 1c96661..294c6b6 100644
--- a/src/writer/spirv/builder_literal_test.cc
+++ b/src/writer/spirv/builder_literal_test.cc
@@ -38,7 +38,7 @@
ast::Module mod;
Builder b(&mod);
- auto id = b.GenerateLiteralIfNeeded(&b_true);
+ auto id = b.GenerateLiteralIfNeeded(nullptr, &b_true);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -53,7 +53,7 @@
ast::Module mod;
Builder b(&mod);
- auto id = b.GenerateLiteralIfNeeded(&b_false);
+ auto id = b.GenerateLiteralIfNeeded(nullptr, &b_false);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -69,11 +69,11 @@
ast::Module mod;
Builder b(&mod);
- ASSERT_NE(b.GenerateLiteralIfNeeded(&b_true), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &b_true), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
- ASSERT_NE(b.GenerateLiteralIfNeeded(&b_false), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &b_false), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
- ASSERT_NE(b.GenerateLiteralIfNeeded(&b_true), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &b_true), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
@@ -88,7 +88,7 @@
ast::Module mod;
Builder b(&mod);
- auto id = b.GenerateLiteralIfNeeded(&i);
+ auto id = b.GenerateLiteralIfNeeded(nullptr, &i);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -104,8 +104,8 @@
ast::Module mod;
Builder b(&mod);
- ASSERT_NE(b.GenerateLiteralIfNeeded(&i1), 0u);
- ASSERT_NE(b.GenerateLiteralIfNeeded(&i2), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i1), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i2), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
@@ -119,7 +119,7 @@
ast::Module mod;
Builder b(&mod);
- auto id = b.GenerateLiteralIfNeeded(&i);
+ auto id = b.GenerateLiteralIfNeeded(nullptr, &i);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -135,8 +135,8 @@
ast::Module mod;
Builder b(&mod);
- ASSERT_NE(b.GenerateLiteralIfNeeded(&i1), 0u);
- ASSERT_NE(b.GenerateLiteralIfNeeded(&i2), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i1), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i2), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
@@ -150,7 +150,7 @@
ast::Module mod;
Builder b(&mod);
- auto id = b.GenerateLiteralIfNeeded(&i);
+ auto id = b.GenerateLiteralIfNeeded(nullptr, &i);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(2u, id);
@@ -166,8 +166,8 @@
ast::Module mod;
Builder b(&mod);
- ASSERT_NE(b.GenerateLiteralIfNeeded(&i1), 0u);
- ASSERT_NE(b.GenerateLiteralIfNeeded(&i2), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i1), 0u);
+ ASSERT_NE(b.GenerateLiteralIfNeeded(nullptr, &i2), 0u);
ASSERT_FALSE(b.has_error()) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32