ast: Remove TypeConstructorExpression
Add a new 'Target' to the ast::CallExpression, which can be either an
Identifier or Type. The Identifier may resolve to a Type, if the Type is
a structure or alias.
The Resolver now resolves the CallExpression target to one of the
following sem::CallTargets:
* sem::Function
* sem::Intrinsic
* sem::TypeConstructor
* sem::TypeCast
This change will allow us to remove the type tracking logic from the WGSL
parser, which is required for out-of-order module scope declarations.
Bug: tint:888
Bug: tint:1266
Change-Id: I696f117115a50981fd5c102a0d7764641bb755dd
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68525
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/src/BUILD.gn b/src/BUILD.gn
index bc38876..8140244 100644
--- a/src/BUILD.gn
+++ b/src/BUILD.gn
@@ -314,8 +314,6 @@
"ast/texture.h",
"ast/traverse_expressions.h",
"ast/type.h",
- "ast/type_constructor_expression.cc",
- "ast/type_constructor_expression.h",
"ast/type_decl.cc",
"ast/type_decl.h",
"ast/type_name.cc",
@@ -408,8 +406,8 @@
"sem/storage_texture_type.h",
"sem/switch_statement.h",
"sem/texture_type.h",
- "sem/type_cast.h",
"sem/type_constructor.h",
+ "sem/type_conversion.h",
"sem/type.h",
"sem/type_manager.h",
"sem/type_mappings.h",
@@ -576,10 +574,10 @@
"sem/switch_statement.h",
"sem/texture_type.cc",
"sem/texture_type.h",
- "sem/type_cast.cc",
- "sem/type_cast.h",
"sem/type_constructor.cc",
"sem/type_constructor.h",
+ "sem/type_conversion.cc",
+ "sem/type_conversion.h",
"sem/type.cc",
"sem/type.h",
"sem/type_manager.cc",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 06758d8..866edf9 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -177,8 +177,6 @@
ast/texture.cc
ast/texture.h
ast/traverse_expressions.h
- ast/type_constructor_expression.cc
- ast/type_constructor_expression.h
ast/type_name.cc
ast/type_name.h
ast/ast_type.cc # TODO(bclayton) - rename to type.cc
@@ -379,10 +377,10 @@
sem/switch_statement.h
sem/texture_type.cc
sem/texture_type.h
- sem/type_cast.cc
- sem/type_cast.h
sem/type_constructor.cc
sem/type_constructor.h
+ sem/type_conversion.cc
+ sem/type_conversion.h
sem/type.cc
sem/type.h
sem/type_manager.cc
@@ -644,7 +642,6 @@
ast/test_helper.h
ast/texture_test.cc
ast/traverse_expressions_test.cc
- ast/type_constructor_expression_test.cc
ast/u32_test.cc
ast/uint_literal_expression_test.cc
ast/unary_op_expression_test.cc
diff --git a/src/ast/call_expression.cc b/src/ast/call_expression.cc
index d4ffa6c..c3fc629 100644
--- a/src/ast/call_expression.cc
+++ b/src/ast/call_expression.cc
@@ -21,13 +21,39 @@
namespace tint {
namespace ast {
+namespace {
+CallExpression::Target ToTarget(const IdentifierExpression* name) {
+ CallExpression::Target target;
+ target.name = name;
+ return target;
+}
+CallExpression::Target ToTarget(const Type* type) {
+ CallExpression::Target target;
+ target.type = type;
+ return target;
+}
+} // namespace
+
CallExpression::CallExpression(ProgramID pid,
const Source& src,
- const IdentifierExpression* fn,
+ const IdentifierExpression* name,
ExpressionList a)
- : Base(pid, src), func(fn), args(a) {
- TINT_ASSERT(AST, func);
- TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, func, program_id);
+ : Base(pid, src), target(ToTarget(name)), args(a) {
+ TINT_ASSERT(AST, name);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, name, program_id);
+ for (auto* arg : args) {
+ TINT_ASSERT(AST, arg);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, arg, program_id);
+ }
+}
+
+CallExpression::CallExpression(ProgramID pid,
+ const Source& src,
+ const Type* type,
+ ExpressionList a)
+ : Base(pid, src), target(ToTarget(type)), args(a) {
+ TINT_ASSERT(AST, type);
+ TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, type, program_id);
for (auto* arg : args) {
TINT_ASSERT(AST, arg);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, arg, program_id);
@@ -41,9 +67,11 @@
const CallExpression* CallExpression::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source);
- auto* fn = ctx->Clone(func);
auto p = ctx->Clone(args);
- return ctx->dst->create<CallExpression>(src, fn, p);
+ return target.name
+ ? ctx->dst->create<CallExpression>(src, ctx->Clone(target.name), p)
+ : ctx->dst->create<CallExpression>(src, ctx->Clone(target.type),
+ p);
}
} // namespace ast
diff --git a/src/ast/call_expression.h b/src/ast/call_expression.h
index 4bd6fc9..68ef6cc 100644
--- a/src/ast/call_expression.h
+++ b/src/ast/call_expression.h
@@ -21,20 +21,36 @@
namespace ast {
// Forward declarations.
+class Type;
class IdentifierExpression;
-/// A call expression
+/// A call expression - represents either a:
+/// * sem::Function
+/// * sem::Intrinsic
+/// * sem::TypeConstructor
+/// * sem::TypeConversion
class CallExpression : public Castable<CallExpression, Expression> {
public:
/// Constructor
/// @param program_id the identifier of the program that owns this node
/// @param source the call expression source
- /// @param func the function
+ /// @param name the function or type name
/// @param args the arguments
CallExpression(ProgramID program_id,
const Source& source,
- const IdentifierExpression* func,
+ const IdentifierExpression* name,
ExpressionList args);
+
+ /// Constructor
+ /// @param program_id the identifier of the program that owns this node
+ /// @param source the call expression source
+ /// @param type the type
+ /// @param args the arguments
+ CallExpression(ProgramID program_id,
+ const Source& source,
+ const Type* type,
+ ExpressionList args);
+
/// Move constructor
CallExpression(CallExpression&&);
~CallExpression() override;
@@ -45,8 +61,19 @@
/// @return the newly cloned node
const CallExpression* Clone(CloneContext* ctx) const override;
+ /// Target is either an identifier, or a Type.
+ /// One of these must be nullptr and the other a non-nullptr.
+ struct Target {
+ /// name is a function or intrinsic to call, or type name to construct or
+ /// cast-to
+ const IdentifierExpression* name = nullptr;
+ /// type to construct or cast-to
+ const Type* type = nullptr;
+ };
+
/// The target function
- const IdentifierExpression* const func;
+ const Target target;
+
/// The arguments
const ExpressionList args;
};
diff --git a/src/ast/call_expression_test.cc b/src/ast/call_expression_test.cc
index af250c0..ea91b25 100644
--- a/src/ast/call_expression_test.cc
+++ b/src/ast/call_expression_test.cc
@@ -21,14 +21,15 @@
using CallExpressionTest = TestHelper;
-TEST_F(CallExpressionTest, Creation) {
+TEST_F(CallExpressionTest, CreationIdentifier) {
auto* func = Expr("func");
ExpressionList params;
params.push_back(Expr("param1"));
params.push_back(Expr("param2"));
auto* stmt = create<CallExpression>(func, params);
- EXPECT_EQ(stmt->func, func);
+ EXPECT_EQ(stmt->target.name, func);
+ EXPECT_EQ(stmt->target.type, nullptr);
const auto& vec = stmt->args;
ASSERT_EQ(vec.size(), 2u);
@@ -36,10 +37,39 @@
EXPECT_EQ(vec[1], params[1]);
}
-TEST_F(CallExpressionTest, Creation_WithSource) {
+TEST_F(CallExpressionTest, CreationIdentifier_WithSource) {
auto* func = Expr("func");
- auto* stmt = create<CallExpression>(Source{Source::Location{20, 2}}, func,
- ExpressionList{});
+ auto* stmt = create<CallExpression>(Source{{20, 2}}, func, ExpressionList{});
+ EXPECT_EQ(stmt->target.name, func);
+ EXPECT_EQ(stmt->target.type, nullptr);
+
+ auto src = stmt->source;
+ EXPECT_EQ(src.range.begin.line, 20u);
+ EXPECT_EQ(src.range.begin.column, 2u);
+}
+
+TEST_F(CallExpressionTest, CreationType) {
+ auto* type = ty.f32();
+ ExpressionList params;
+ params.push_back(Expr("param1"));
+ params.push_back(Expr("param2"));
+
+ auto* stmt = create<CallExpression>(type, params);
+ EXPECT_EQ(stmt->target.name, nullptr);
+ EXPECT_EQ(stmt->target.type, type);
+
+ const auto& vec = stmt->args;
+ ASSERT_EQ(vec.size(), 2u);
+ EXPECT_EQ(vec[0], params[0]);
+ EXPECT_EQ(vec[1], params[1]);
+}
+
+TEST_F(CallExpressionTest, CreationType_WithSource) {
+ auto* type = ty.f32();
+ auto* stmt = create<CallExpression>(Source{{20, 2}}, type, ExpressionList{});
+ EXPECT_EQ(stmt->target.name, nullptr);
+ EXPECT_EQ(stmt->target.type, type);
+
auto src = stmt->source;
EXPECT_EQ(src.range.begin.line, 20u);
EXPECT_EQ(src.range.begin.column, 2u);
@@ -51,11 +81,21 @@
EXPECT_TRUE(stmt->Is<CallExpression>());
}
-TEST_F(CallExpressionTest, Assert_Null_Function) {
+TEST_F(CallExpressionTest, Assert_Null_Identifier) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
- b.create<CallExpression>(nullptr, ExpressionList{});
+ b.create<CallExpression>(static_cast<IdentifierExpression*>(nullptr),
+ ExpressionList{});
+ },
+ "internal compiler error");
+}
+
+TEST_F(CallExpressionTest, Assert_Null_Type) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b;
+ b.create<CallExpression>(static_cast<Type*>(nullptr), ExpressionList{});
},
"internal compiler error");
}
@@ -73,7 +113,7 @@
"internal compiler error");
}
-TEST_F(CallExpressionTest, Assert_DifferentProgramID_Function) {
+TEST_F(CallExpressionTest, Assert_DifferentProgramID_Identifier) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b1;
@@ -83,6 +123,16 @@
"internal compiler error");
}
+TEST_F(CallExpressionTest, Assert_DifferentProgramID_Type) {
+ EXPECT_FATAL_FAILURE(
+ {
+ ProgramBuilder b1;
+ ProgramBuilder b2;
+ b1.create<CallExpression>(b2.ty.f32(), ExpressionList{});
+ },
+ "internal compiler error");
+}
+
TEST_F(CallExpressionTest, Assert_DifferentProgramID_Param) {
EXPECT_FATAL_FAILURE(
{
diff --git a/src/ast/traverse_expressions.h b/src/ast/traverse_expressions.h
index 2792743..88d3dfc 100644
--- a/src/ast/traverse_expressions.h
+++ b/src/ast/traverse_expressions.h
@@ -24,7 +24,6 @@
#include "src/ast/literal_expression.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/phony_expression.h"
-#include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_op_expression.h"
#include "src/utils/reverse.h"
@@ -113,8 +112,6 @@
// function name in the traversal.
// to_visit.push_back(call->func);
push_list(call->args);
- } else if (auto* type_ctor = expr->As<TypeConstructorExpression>()) {
- push_list(type_ctor->values);
} else if (auto* member = expr->As<MemberAccessorExpression>()) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
// member name in the traversal.
diff --git a/src/ast/traverse_expressions_test.cc b/src/ast/traverse_expressions_test.cc
index ecf9416..f5f3324 100644
--- a/src/ast/traverse_expressions_test.cc
+++ b/src/ast/traverse_expressions_test.cc
@@ -124,31 +124,6 @@
}
}
-TEST_F(TraverseExpressionsTest, DescendTypeConstructorExpression) {
- std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
- std::vector<const ast::Expression*> c = {vec2<i32>(e[0], e[1]),
- vec2<i32>(e[2], e[3])};
- auto* root = vec2<i32>(c[0], c[1]);
- {
- std::vector<const ast::Expression*> l2r;
- TraverseExpressions<TraverseOrder::LeftToRight>(
- root, Diagnostics(), [&](const ast::Expression* expr) {
- l2r.push_back(expr);
- return ast::TraverseAction::Descend;
- });
- EXPECT_THAT(l2r, ElementsAre(root, c[0], e[0], e[1], c[1], e[2], e[3]));
- }
- {
- std::vector<const ast::Expression*> r2l;
- TraverseExpressions<TraverseOrder::RightToLeft>(
- root, Diagnostics(), [&](const ast::Expression* expr) {
- r2l.push_back(expr);
- return ast::TraverseAction::Descend;
- });
- EXPECT_THAT(r2l, ElementsAre(root, c[1], e[3], e[2], c[0], e[1], e[0]));
- }
-}
-
// TODO(crbug.com/tint/1257): Test ignores member accessor 'member' field.
// Replace with the test below when fixed.
TEST_F(TraverseExpressionsTest, DescendMemberIndexExpression) {
diff --git a/src/ast/type_constructor_expression.cc b/src/ast/type_constructor_expression.cc
deleted file mode 100644
index 2745f2e..0000000
--- a/src/ast/type_constructor_expression.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2020 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/ast/type_constructor_expression.h"
-
-#include "src/program_builder.h"
-
-TINT_INSTANTIATE_TYPEINFO(tint::ast::TypeConstructorExpression);
-
-namespace tint {
-namespace ast {
-
-TypeConstructorExpression::TypeConstructorExpression(ProgramID pid,
- const Source& src,
- const ast::Type* ty,
- ExpressionList vals)
- : Base(pid, src), type(ty), values(std::move(vals)) {
- TINT_ASSERT(AST, type);
- for (auto* val : values) {
- TINT_ASSERT(AST, val);
- TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, val, program_id);
- }
-}
-
-TypeConstructorExpression::TypeConstructorExpression(
- TypeConstructorExpression&&) = default;
-
-TypeConstructorExpression::~TypeConstructorExpression() = default;
-
-const TypeConstructorExpression* TypeConstructorExpression::Clone(
- CloneContext* ctx) const {
- // Clone arguments outside of create() call to have deterministic ordering
- auto src = ctx->Clone(source);
- auto* ty = ctx->Clone(type);
- auto vals = ctx->Clone(values);
- return ctx->dst->create<TypeConstructorExpression>(src, ty, vals);
-}
-
-} // namespace ast
-} // namespace tint
diff --git a/src/ast/type_constructor_expression.h b/src/ast/type_constructor_expression.h
deleted file mode 100644
index d6ed7b5..0000000
--- a/src/ast/type_constructor_expression.h
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2020 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef SRC_AST_TYPE_CONSTRUCTOR_EXPRESSION_H_
-#define SRC_AST_TYPE_CONSTRUCTOR_EXPRESSION_H_
-
-#include <utility>
-
-#include "src/ast/expression.h"
-
-namespace tint {
-namespace ast {
-
-// Forward declaration
-class Type;
-
-/// A type specific constructor
-class TypeConstructorExpression
- : public Castable<TypeConstructorExpression, Expression> {
- public:
- /// Constructor
- /// @param pid the identifier of the program that owns this node
- /// @param src the source of this node
- /// @param type the type
- /// @param values the constructor values
- TypeConstructorExpression(ProgramID pid,
- const Source& src,
- const ast::Type* type,
- ExpressionList values);
- /// Move constructor
- TypeConstructorExpression(TypeConstructorExpression&&);
- ~TypeConstructorExpression() override;
-
- /// Clones this node and all transitive child nodes using the `CloneContext`
- /// `ctx`.
- /// @param ctx the clone context
- /// @return the newly cloned node
- const TypeConstructorExpression* Clone(CloneContext* ctx) const override;
-
- /// The type
- const ast::Type* const type;
-
- /// The values
- const ExpressionList values;
-};
-
-} // namespace ast
-} // namespace tint
-
-#endif // SRC_AST_TYPE_CONSTRUCTOR_EXPRESSION_H_
diff --git a/src/ast/type_constructor_expression_test.cc b/src/ast/type_constructor_expression_test.cc
deleted file mode 100644
index fa87cc5..0000000
--- a/src/ast/type_constructor_expression_test.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright 2020 The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "gtest/gtest-spi.h"
-#include "src/ast/test_helper.h"
-
-namespace tint {
-namespace ast {
-namespace {
-
-using TypeConstructorExpressionTest = TestHelper;
-
-TEST_F(TypeConstructorExpressionTest, Creation) {
- ExpressionList expr;
- expr.push_back(Expr("expr"));
-
- auto* t = create<TypeConstructorExpression>(ty.f32(), expr);
- EXPECT_TRUE(t->type->Is<ast::F32>());
- ASSERT_EQ(t->values.size(), 1u);
- EXPECT_EQ(t->values[0], expr[0]);
-}
-
-TEST_F(TypeConstructorExpressionTest, Creation_WithSource) {
- ExpressionList expr;
- expr.push_back(Expr("expr"));
-
- auto* t = create<TypeConstructorExpression>(Source{Source::Location{20, 2}},
- ty.f32(), expr);
- auto src = t->source;
- EXPECT_EQ(src.range.begin.line, 20u);
- EXPECT_EQ(src.range.begin.column, 2u);
-}
-
-TEST_F(TypeConstructorExpressionTest, IsTypeConstructor) {
- ExpressionList expr;
- expr.push_back(Expr("expr"));
-
- auto* t = create<TypeConstructorExpression>(ty.f32(), expr);
- EXPECT_TRUE(t->Is<TypeConstructorExpression>());
-}
-
-TEST_F(TypeConstructorExpressionTest, Assert_Null_Type) {
- EXPECT_FATAL_FAILURE(
- {
- ProgramBuilder b;
- b.create<TypeConstructorExpression>(nullptr, ExpressionList{b.Expr(1)});
- },
- "internal compiler error");
-}
-
-TEST_F(TypeConstructorExpressionTest, Assert_Null_Value) {
- EXPECT_FATAL_FAILURE(
- {
- ProgramBuilder b;
- b.create<TypeConstructorExpression>(b.ty.i32(),
- ExpressionList{nullptr});
- },
- "internal compiler error");
-}
-
-TEST_F(TypeConstructorExpressionTest, Assert_DifferentProgramID_Value) {
- EXPECT_FATAL_FAILURE(
- {
- ProgramBuilder b1;
- ProgramBuilder b2;
- b1.create<TypeConstructorExpression>(b1.ty.i32(),
- ExpressionList{b2.Expr(1)});
- },
- "internal compiler error");
-}
-
-} // namespace
-} // namespace ast
-} // namespace tint
diff --git a/src/program_builder.h b/src/program_builder.h
index 6d87945..ae6f3f5 100644
--- a/src/program_builder.h
+++ b/src/program_builder.h
@@ -64,7 +64,6 @@
#include "src/ast/struct_member_offset_decoration.h"
#include "src/ast/struct_member_size_decoration.h"
#include "src/ast/switch_statement.h"
-#include "src/ast/type_constructor_expression.h"
#include "src/ast/type_name.h"
#include "src/ast/u32.h"
#include "src/ast/uint_literal_expression.h"
@@ -1125,35 +1124,33 @@
ast::ExpressionList ExprList(ast::ExpressionList list) { return list; }
/// @param args the arguments for the type constructor
- /// @return an `ast::TypeConstructorExpression` of type `ty`, with the values
+ /// @return an `ast::CallExpression` of type `ty`, with the values
/// of `args` converted to `ast::Expression`s using `Expr()`
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* Construct(ARGS&&... args) {
+ const ast::CallExpression* Construct(ARGS&&... args) {
return Construct(ty.Of<T>(), std::forward<ARGS>(args)...);
}
/// @param type the type to construct
/// @param args the arguments for the constructor
- /// @return an `ast::TypeConstructorExpression` of `type` constructed with the
+ /// @return an `ast::CallExpression` of `type` constructed with the
/// values `args`.
template <typename... ARGS>
- const ast::TypeConstructorExpression* Construct(const ast::Type* type,
- ARGS&&... args) {
- return create<ast::TypeConstructorExpression>(
- type, ExprList(std::forward<ARGS>(args)...));
+ const ast::CallExpression* Construct(const ast::Type* type, ARGS&&... args) {
+ return Construct(source_, type, std::forward<ARGS>(args)...);
}
/// @param source the source information
/// @param type the type to construct
/// @param args the arguments for the constructor
- /// @return an `ast::TypeConstructorExpression` of `type` constructed with the
+ /// @return an `ast::CallExpression` of `type` constructed with the
/// values `args`.
template <typename... ARGS>
- const ast::TypeConstructorExpression* Construct(const Source& source,
- const ast::Type* type,
- ARGS&&... args) {
- return create<ast::TypeConstructorExpression>(
- source, type, ExprList(std::forward<ARGS>(args)...));
+ const ast::CallExpression* Construct(const Source& source,
+ const ast::Type* type,
+ ARGS&&... args) {
+ return create<ast::CallExpression>(source, type,
+ ExprList(std::forward<ARGS>(args)...));
}
/// @param expr the expression for the bitcast
@@ -1189,128 +1186,128 @@
/// @param args the arguments for the vector constructor
/// @param type the vector type
/// @param size the vector size
- /// @return an `ast::TypeConstructorExpression` of a `size`-element vector of
+ /// @return an `ast::CallExpression` of a `size`-element vector of
/// type `type`, constructed with the values `args`.
template <typename... ARGS>
- const ast::TypeConstructorExpression* vec(const ast::Type* type,
- uint32_t size,
- ARGS&&... args) {
+ const ast::CallExpression* vec(const ast::Type* type,
+ uint32_t size,
+ ARGS&&... args) {
return Construct(ty.vec(type, size), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the vector constructor
- /// @return an `ast::TypeConstructorExpression` of a 2-element vector of type
+ /// @return an `ast::CallExpression` of a 2-element vector of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* vec2(ARGS&&... args) {
+ const ast::CallExpression* vec2(ARGS&&... args) {
return Construct(ty.vec2<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the vector constructor
- /// @return an `ast::TypeConstructorExpression` of a 3-element vector of type
+ /// @return an `ast::CallExpression` of a 3-element vector of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* vec3(ARGS&&... args) {
+ const ast::CallExpression* vec3(ARGS&&... args) {
return Construct(ty.vec3<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the vector constructor
- /// @return an `ast::TypeConstructorExpression` of a 4-element vector of type
+ /// @return an `ast::CallExpression` of a 4-element vector of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* vec4(ARGS&&... args) {
+ const ast::CallExpression* vec4(ARGS&&... args) {
return Construct(ty.vec4<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the matrix constructor
- /// @return an `ast::TypeConstructorExpression` of a 2x2 matrix of type
+ /// @return an `ast::CallExpression` of a 2x2 matrix of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* mat2x2(ARGS&&... args) {
+ const ast::CallExpression* mat2x2(ARGS&&... args) {
return Construct(ty.mat2x2<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the matrix constructor
- /// @return an `ast::TypeConstructorExpression` of a 2x3 matrix of type
+ /// @return an `ast::CallExpression` of a 2x3 matrix of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* mat2x3(ARGS&&... args) {
+ const ast::CallExpression* mat2x3(ARGS&&... args) {
return Construct(ty.mat2x3<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the matrix constructor
- /// @return an `ast::TypeConstructorExpression` of a 2x4 matrix of type
+ /// @return an `ast::CallExpression` of a 2x4 matrix of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* mat2x4(ARGS&&... args) {
+ const ast::CallExpression* mat2x4(ARGS&&... args) {
return Construct(ty.mat2x4<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the matrix constructor
- /// @return an `ast::TypeConstructorExpression` of a 3x2 matrix of type
+ /// @return an `ast::CallExpression` of a 3x2 matrix of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* mat3x2(ARGS&&... args) {
+ const ast::CallExpression* mat3x2(ARGS&&... args) {
return Construct(ty.mat3x2<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the matrix constructor
- /// @return an `ast::TypeConstructorExpression` of a 3x3 matrix of type
+ /// @return an `ast::CallExpression` of a 3x3 matrix of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* mat3x3(ARGS&&... args) {
+ const ast::CallExpression* mat3x3(ARGS&&... args) {
return Construct(ty.mat3x3<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the matrix constructor
- /// @return an `ast::TypeConstructorExpression` of a 3x4 matrix of type
+ /// @return an `ast::CallExpression` of a 3x4 matrix of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* mat3x4(ARGS&&... args) {
+ const ast::CallExpression* mat3x4(ARGS&&... args) {
return Construct(ty.mat3x4<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the matrix constructor
- /// @return an `ast::TypeConstructorExpression` of a 4x2 matrix of type
+ /// @return an `ast::CallExpression` of a 4x2 matrix of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* mat4x2(ARGS&&... args) {
+ const ast::CallExpression* mat4x2(ARGS&&... args) {
return Construct(ty.mat4x2<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the matrix constructor
- /// @return an `ast::TypeConstructorExpression` of a 4x3 matrix of type
+ /// @return an `ast::CallExpression` of a 4x3 matrix of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* mat4x3(ARGS&&... args) {
+ const ast::CallExpression* mat4x3(ARGS&&... args) {
return Construct(ty.mat4x3<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the matrix constructor
- /// @return an `ast::TypeConstructorExpression` of a 4x4 matrix of type
+ /// @return an `ast::CallExpression` of a 4x4 matrix of type
/// `T`, constructed with the values `args`.
template <typename T, typename... ARGS>
- const ast::TypeConstructorExpression* mat4x4(ARGS&&... args) {
+ const ast::CallExpression* mat4x4(ARGS&&... args) {
return Construct(ty.mat4x4<T>(), std::forward<ARGS>(args)...);
}
/// @param args the arguments for the array constructor
- /// @return an `ast::TypeConstructorExpression` of an array with element type
+ /// @return an `ast::CallExpression` of an array with element type
/// `T` and size `N`, constructed with the values `args`.
template <typename T, int N, typename... ARGS>
- const ast::TypeConstructorExpression* array(ARGS&&... args) {
+ const ast::CallExpression* array(ARGS&&... args) {
return Construct(ty.array<T, N>(), std::forward<ARGS>(args)...);
}
/// @param subtype the array element type
/// @param n the array size. nullptr represents a runtime-array.
/// @param args the arguments for the array constructor
- /// @return an `ast::TypeConstructorExpression` of an array with element type
+ /// @return an `ast::CallExpression` of an array with element type
/// `subtype`, constructed with the values `args`.
template <typename EXPR, typename... ARGS>
- const ast::TypeConstructorExpression* array(const ast::Type* subtype,
- EXPR&& n,
- ARGS&&... args) {
+ const ast::CallExpression* array(const ast::Type* subtype,
+ EXPR&& n,
+ ARGS&&... args) {
return Construct(ty.array(subtype, std::forward<EXPR>(n)),
std::forward<ARGS>(args)...);
}
diff --git a/src/reader/wgsl/parser_impl_call_stmt_test.cc b/src/reader/wgsl/parser_impl_call_stmt_test.cc
index ed49d4b..83d31ee 100644
--- a/src/reader/wgsl/parser_impl_call_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_call_stmt_test.cc
@@ -36,7 +36,7 @@
ASSERT_TRUE(e->Is<ast::CallStatement>());
auto* c = e->As<ast::CallStatement>()->expr;
- EXPECT_EQ(c->func->symbol, p->builder().Symbols().Get("a"));
+ EXPECT_EQ(c->target.name->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(c->args.size(), 0u);
}
@@ -52,7 +52,7 @@
ASSERT_TRUE(e->Is<ast::CallStatement>());
auto* c = e->As<ast::CallStatement>()->expr;
- EXPECT_EQ(c->func->symbol, p->builder().Symbols().Get("a"));
+ EXPECT_EQ(c->target.name->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(c->args.size(), 3u);
EXPECT_TRUE(c->args[0]->Is<ast::IntLiteralExpression>());
@@ -71,7 +71,7 @@
ASSERT_TRUE(e->Is<ast::CallStatement>());
auto* c = e->As<ast::CallStatement>()->expr;
- EXPECT_EQ(c->func->symbol, p->builder().Symbols().Get("a"));
+ EXPECT_EQ(c->target.name->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(c->args.size(), 2u);
EXPECT_TRUE(c->args[0]->Is<ast::IntLiteralExpression>());
diff --git a/src/reader/wgsl/parser_impl_const_expr_test.cc b/src/reader/wgsl/parser_impl_const_expr_test.cc
index f5958d1..e31763f 100644
--- a/src/reader/wgsl/parser_impl_const_expr_test.cc
+++ b/src/reader/wgsl/parser_impl_const_expr_test.cc
@@ -24,20 +24,19 @@
auto e = p->expect_const_expr();
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored);
- ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
+ ASSERT_TRUE(e->Is<ast::CallExpression>());
- auto* t = e->As<ast::TypeConstructorExpression>();
- ASSERT_TRUE(t->type->Is<ast::Vector>());
- EXPECT_EQ(t->type->As<ast::Vector>()->width, 2u);
+ auto* t = e->As<ast::CallExpression>();
+ ASSERT_TRUE(t->target.type->Is<ast::Vector>());
+ EXPECT_EQ(t->target.type->As<ast::Vector>()->width, 2u);
- ASSERT_EQ(t->values.size(), 2u);
- auto& v = t->values;
+ ASSERT_EQ(t->args.size(), 2u);
- ASSERT_TRUE(v[0]->Is<ast::FloatLiteralExpression>());
- EXPECT_FLOAT_EQ(v[0]->As<ast::FloatLiteralExpression>()->value, 1.);
+ ASSERT_TRUE(t->args[0]->Is<ast::FloatLiteralExpression>());
+ EXPECT_FLOAT_EQ(t->args[0]->As<ast::FloatLiteralExpression>()->value, 1.);
- ASSERT_TRUE(v[1]->Is<ast::FloatLiteralExpression>());
- EXPECT_FLOAT_EQ(v[1]->As<ast::FloatLiteralExpression>()->value, 2.);
+ ASSERT_TRUE(t->args[1]->Is<ast::FloatLiteralExpression>());
+ EXPECT_FLOAT_EQ(t->args[1]->As<ast::FloatLiteralExpression>()->value, 2.);
}
TEST_F(ParserImplTest, ConstExpr_TypeDecl_Empty) {
@@ -45,13 +44,13 @@
auto e = p->expect_const_expr();
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored);
- ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
+ ASSERT_TRUE(e->Is<ast::CallExpression>());
- auto* t = e->As<ast::TypeConstructorExpression>();
- ASSERT_TRUE(t->type->Is<ast::Vector>());
- EXPECT_EQ(t->type->As<ast::Vector>()->width, 2u);
+ auto* t = e->As<ast::CallExpression>();
+ ASSERT_TRUE(t->target.type->Is<ast::Vector>());
+ EXPECT_EQ(t->target.type->As<ast::Vector>()->width, 2u);
- ASSERT_EQ(t->values.size(), 0u);
+ ASSERT_EQ(t->args.size(), 0u);
}
TEST_F(ParserImplTest, ConstExpr_TypeDecl_TrailingComma) {
@@ -59,15 +58,15 @@
auto e = p->expect_const_expr();
ASSERT_FALSE(p->has_error()) << p->error();
ASSERT_FALSE(e.errored);
- ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
+ ASSERT_TRUE(e->Is<ast::CallExpression>());
- auto* t = e->As<ast::TypeConstructorExpression>();
- ASSERT_TRUE(t->type->Is<ast::Vector>());
- EXPECT_EQ(t->type->As<ast::Vector>()->width, 2u);
+ auto* t = e->As<ast::CallExpression>();
+ ASSERT_TRUE(t->target.type->Is<ast::Vector>());
+ EXPECT_EQ(t->target.type->As<ast::Vector>()->width, 2u);
- ASSERT_EQ(t->values.size(), 2u);
- ASSERT_TRUE(t->values[0]->Is<ast::LiteralExpression>());
- ASSERT_TRUE(t->values[1]->Is<ast::LiteralExpression>());
+ ASSERT_EQ(t->args.size(), 2u);
+ ASSERT_TRUE(t->args[0]->Is<ast::LiteralExpression>());
+ ASSERT_TRUE(t->args[1]->Is<ast::LiteralExpression>());
}
TEST_F(ParserImplTest, ConstExpr_TypeDecl_MissingRightParen) {
@@ -134,7 +133,7 @@
auto e = p->expect_const_expr();
ASSERT_FALSE(e.errored);
- ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
+ ASSERT_TRUE(e->Is<ast::CallExpression>());
}
TEST_F(ParserImplTest, ConstExpr_NotRegisteredType) {
diff --git a/src/reader/wgsl/parser_impl_primary_expression_test.cc b/src/reader/wgsl/parser_impl_primary_expression_test.cc
index 66312fd..0b5eac8 100644
--- a/src/reader/wgsl/parser_impl_primary_expression_test.cc
+++ b/src/reader/wgsl/parser_impl_primary_expression_test.cc
@@ -39,11 +39,13 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
- auto* ty = e->As<ast::TypeConstructorExpression>();
+ ASSERT_TRUE(e->Is<ast::CallExpression>());
+ auto* call = e->As<ast::CallExpression>();
- ASSERT_EQ(ty->values.size(), 4u);
- const auto& val = ty->values;
+ EXPECT_NE(call->target.type, nullptr);
+
+ ASSERT_EQ(call->args.size(), 4u);
+ const auto& val = call->args;
ASSERT_TRUE(val[0]->Is<ast::SintLiteralExpression>());
EXPECT_EQ(val[0]->As<ast::SintLiteralExpression>()->value, 1);
@@ -64,10 +66,11 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
- auto* ty = e->As<ast::TypeConstructorExpression>();
- ASSERT_EQ(ty->values.size(), 0u);
+ ASSERT_TRUE(e->Is<ast::CallExpression>());
+ auto* call = e->As<ast::CallExpression>();
+
+ ASSERT_EQ(call->args.size(), 0u);
}
TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_InvalidTypeDecl) {
@@ -124,15 +127,15 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
- auto* constructor = e->As<ast::TypeConstructorExpression>();
- ASSERT_TRUE(constructor->type->Is<ast::TypeName>());
- EXPECT_EQ(constructor->type->As<ast::TypeName>()->name,
+ ASSERT_TRUE(e->Is<ast::CallExpression>());
+ auto* call = e->As<ast::CallExpression>();
+
+ ASSERT_TRUE(call->target.type->Is<ast::TypeName>());
+ EXPECT_EQ(call->target.type->As<ast::TypeName>()->name,
p->builder().Symbols().Get("S"));
- auto values = constructor->values;
- ASSERT_EQ(values.size(), 0u);
+ ASSERT_EQ(call->args.size(), 0u);
}
TEST_F(ParserImplTest, PrimaryExpression_TypeDecl_StructConstructor_NotEmpty) {
@@ -149,21 +152,21 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
- auto* constructor = e->As<ast::TypeConstructorExpression>();
- ASSERT_TRUE(constructor->type->Is<ast::TypeName>());
- EXPECT_EQ(constructor->type->As<ast::TypeName>()->name,
+ ASSERT_TRUE(e->Is<ast::CallExpression>());
+ auto* call = e->As<ast::CallExpression>();
+
+ ASSERT_TRUE(call->target.type->Is<ast::TypeName>());
+ EXPECT_EQ(call->target.type->As<ast::TypeName>()->name,
p->builder().Symbols().Get("S"));
- auto values = constructor->values;
- ASSERT_EQ(values.size(), 2u);
+ ASSERT_EQ(call->args.size(), 2u);
- ASSERT_TRUE(values[0]->Is<ast::UintLiteralExpression>());
- EXPECT_EQ(values[0]->As<ast::UintLiteralExpression>()->value, 1u);
+ ASSERT_TRUE(call->args[0]->Is<ast::UintLiteralExpression>());
+ EXPECT_EQ(call->args[0]->As<ast::UintLiteralExpression>()->value, 1u);
- ASSERT_TRUE(values[1]->Is<ast::FloatLiteralExpression>());
- EXPECT_EQ(values[1]->As<ast::FloatLiteralExpression>()->value, 2.f);
+ ASSERT_TRUE(call->args[1]->Is<ast::FloatLiteralExpression>());
+ EXPECT_EQ(call->args[1]->As<ast::FloatLiteralExpression>()->value, 2.f);
}
TEST_F(ParserImplTest, PrimaryExpression_ConstLiteral_True) {
@@ -225,13 +228,14 @@
EXPECT_FALSE(e.errored);
EXPECT_FALSE(p->has_error()) << p->error();
ASSERT_NE(e.value, nullptr);
- ASSERT_TRUE(e->Is<ast::TypeConstructorExpression>());
- auto* c = e->As<ast::TypeConstructorExpression>();
- ASSERT_TRUE(c->type->Is<ast::F32>());
- ASSERT_EQ(c->values.size(), 1u);
+ ASSERT_TRUE(e->Is<ast::CallExpression>());
+ auto* call = e->As<ast::CallExpression>();
- ASSERT_TRUE(c->values[0]->Is<ast::IntLiteralExpression>());
+ ASSERT_TRUE(call->target.type->Is<ast::F32>());
+ ASSERT_EQ(call->args.size(), 1u);
+
+ ASSERT_TRUE(call->args[0]->Is<ast::IntLiteralExpression>());
}
TEST_F(ParserImplTest, PrimaryExpression_Bitcast) {
diff --git a/src/reader/wgsl/parser_impl_singular_expression_test.cc b/src/reader/wgsl/parser_impl_singular_expression_test.cc
index b0d1f45..4b53da5 100644
--- a/src/reader/wgsl/parser_impl_singular_expression_test.cc
+++ b/src/reader/wgsl/parser_impl_singular_expression_test.cc
@@ -97,7 +97,7 @@
ASSERT_TRUE(e->Is<ast::CallExpression>());
auto* c = e->As<ast::CallExpression>();
- EXPECT_EQ(c->func->symbol, p->builder().Symbols().Get("a"));
+ EXPECT_EQ(c->target.name->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(c->args.size(), 0u);
}
@@ -113,7 +113,7 @@
ASSERT_TRUE(e->Is<ast::CallExpression>());
auto* c = e->As<ast::CallExpression>();
- EXPECT_EQ(c->func->symbol, p->builder().Symbols().Get("test"));
+ EXPECT_EQ(c->target.name->symbol, p->builder().Symbols().Get("test"));
EXPECT_EQ(c->args.size(), 3u);
EXPECT_TRUE(c->args[0]->Is<ast::IntLiteralExpression>());
diff --git a/src/reader/wgsl/parser_impl_variable_stmt_test.cc b/src/reader/wgsl/parser_impl_variable_stmt_test.cc
index c248e7a..1d5748a 100644
--- a/src/reader/wgsl/parser_impl_variable_stmt_test.cc
+++ b/src/reader/wgsl/parser_impl_variable_stmt_test.cc
@@ -90,7 +90,10 @@
EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->variable->constructor, nullptr);
- EXPECT_TRUE(e->variable->constructor->Is<ast::TypeConstructorExpression>());
+ auto* call = e->variable->constructor->As<ast::CallExpression>();
+ ASSERT_NE(call, nullptr);
+ EXPECT_EQ(call->target.name, nullptr);
+ EXPECT_NE(call->target.type, nullptr);
}
TEST_F(ParserImplTest, VariableStmt_VariableDecl_ArrayInit_NoSpace) {
@@ -105,7 +108,10 @@
EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->variable->constructor, nullptr);
- EXPECT_TRUE(e->variable->constructor->Is<ast::TypeConstructorExpression>());
+ auto* call = e->variable->constructor->As<ast::CallExpression>();
+ ASSERT_NE(call, nullptr);
+ EXPECT_EQ(call->target.name, nullptr);
+ EXPECT_NE(call->target.type, nullptr);
}
TEST_F(ParserImplTest, VariableStmt_VariableDecl_VecInit) {
@@ -120,7 +126,10 @@
EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->variable->constructor, nullptr);
- EXPECT_TRUE(e->variable->constructor->Is<ast::TypeConstructorExpression>());
+ auto* call = e->variable->constructor->As<ast::CallExpression>();
+ ASSERT_NE(call, nullptr);
+ EXPECT_EQ(call->target.name, nullptr);
+ EXPECT_NE(call->target.type, nullptr);
}
TEST_F(ParserImplTest, VariableStmt_VariableDecl_VecInit_NoSpace) {
@@ -135,7 +144,10 @@
EXPECT_EQ(e->variable->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->variable->constructor, nullptr);
- EXPECT_TRUE(e->variable->constructor->Is<ast::TypeConstructorExpression>());
+ auto* call = e->variable->constructor->As<ast::CallExpression>();
+ ASSERT_NE(call, nullptr);
+ EXPECT_EQ(call->target.name, nullptr);
+ EXPECT_NE(call->target.type, nullptr);
}
TEST_F(ParserImplTest, VariableStmt_Let) {
diff --git a/src/resolver/call_test.cc b/src/resolver/call_test.cc
index 585099a..3448726 100644
--- a/src/resolver/call_test.cc
+++ b/src/resolver/call_test.cc
@@ -90,11 +90,15 @@
args.push_back(p.create_value(*this, 0));
}
- Func("foo", std::move(params), ty.f32(), {Return(1.23f)});
- auto* call = Call("foo", std::move(args));
- WrapInFunction(call);
+ auto* func = Func("foo", std::move(params), ty.f32(), {Return(1.23f)});
+ auto* call_expr = Call("foo", std::move(args));
+ WrapInFunction(call_expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* call = Sem().Get(call_expr);
+ EXPECT_NE(call, nullptr);
+ EXPECT_EQ(call->Target(), Sem().Get(func));
}
} // namespace
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 51b6e4c..323dacd 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -70,12 +70,15 @@
#include "src/sem/storage_texture_type.h"
#include "src/sem/struct.h"
#include "src/sem/switch_statement.h"
+#include "src/sem/type_constructor.h"
+#include "src/sem/type_conversion.h"
#include "src/sem/variable.h"
#include "src/utils/defer.h"
#include "src/utils/get_or_create.h"
#include "src/utils/math.h"
#include "src/utils/reverse.h"
#include "src/utils/scoped_assignment.h"
+#include "src/utils/transform.h"
namespace tint {
namespace resolver {
@@ -510,8 +513,8 @@
builder_->create<sem::Reference>(storage_ty, storage_class, access);
}
- if (rhs && !ValidateVariableConstructor(var, storage_class, storage_ty,
- rhs->Type())) {
+ if (rhs && !ValidateVariableConstructorOrCast(var, storage_class, storage_ty,
+ rhs->Type())) {
return nullptr;
}
@@ -641,10 +644,11 @@
}
}
-bool Resolver::ValidateVariableConstructor(const ast::Variable* var,
- ast::StorageClass storage_class,
- const sem::Type* storage_ty,
- const sem::Type* rhs_ty) {
+bool Resolver::ValidateVariableConstructorOrCast(
+ const ast::Variable* var,
+ ast::StorageClass storage_class,
+ const sem::Type* storage_ty,
+ const sem::Type* rhs_ty) {
auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS
// Value type has to match storage type
@@ -2369,8 +2373,6 @@
sem_expr = Bitcast(bitcast);
} else if (auto* call = expr->As<ast::CallExpression>()) {
sem_expr = Call(call);
- } else if (auto* ctor = expr->As<ast::TypeConstructorExpression>()) {
- sem_expr = TypeConstructor(ctor);
} else if (auto* ident = expr->As<ast::IdentifierExpression>()) {
sem_expr = Identifier(ident);
} else if (auto* literal = expr->As<ast::LiteralExpression>()) {
@@ -2462,33 +2464,72 @@
return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
}
-sem::Expression* Resolver::Call(const ast::CallExpression* expr) {
- auto* ident = expr->func;
- Mark(ident);
- auto name = builder_->Symbols().NameFor(ident->symbol);
-
- auto intrinsic_type = sem::ParseIntrinsicType(name);
- auto* call = (intrinsic_type != IntrinsicType::kNone)
- ? IntrinsicCall(expr, intrinsic_type)
- : FunctionCall(expr);
-
- current_function_->AddDirectCall(call);
- return call;
-}
-
-sem::Call* Resolver::IntrinsicCall(const ast::CallExpression* expr,
- sem::IntrinsicType intrinsic_type) {
+sem::Call* Resolver::Call(const ast::CallExpression* expr) {
std::vector<const sem::Expression*> args(expr->args.size());
- std::vector<const sem::Type*> arg_tys(expr->args.size());
+ std::vector<const sem::Type*> arg_tys(args.size());
for (size_t i = 0; i < expr->args.size(); i++) {
auto* arg = Sem(expr->args[i]);
if (!arg) {
return nullptr;
}
args[i] = arg;
- arg_tys[i] = arg->Type();
+ arg_tys[i] = args[i]->Type();
}
+ auto type_ctor_or_conv = [&](const sem::Type* ty) -> sem::Call* {
+ // The call has resolved to a type constructor or cast.
+ if (args.size() == 1) {
+ auto* target = ty;
+ auto* source = args[0]->Type()->UnwrapRef();
+ if ((source != target) && //
+ ((source->is_scalar() && target->is_scalar()) ||
+ (source->Is<sem::Vector>() && target->Is<sem::Vector>()) ||
+ (source->Is<sem::Matrix>() && target->Is<sem::Matrix>()))) {
+ // Note: Matrix types currently cannot be converted (the element type
+ // must only be f32). We implement this for the day we support other
+ // matrix element types.
+ return TypeConversion(expr, ty, args[0], arg_tys[0]);
+ }
+ }
+ return TypeConstructor(expr, ty, std::move(args), std::move(arg_tys));
+ };
+
+ // Resolve the target of the CallExpression to determine whether this is a
+ // function call, cast or type constructor expression.
+ if (expr->target.type) {
+ auto* ty = Type(expr->target.type);
+ if (!ty) {
+ return nullptr;
+ }
+ return type_ctor_or_conv(ty);
+ }
+
+ auto* ident = expr->target.name;
+ Mark(ident);
+
+ auto it = named_type_info_.find(ident->symbol);
+ if (it != named_type_info_.end()) {
+ // We have a type.
+ return type_ctor_or_conv(it->second.sem);
+ }
+
+ // Not a type, treat as a intrinsic / function call.
+ auto name = builder_->Symbols().NameFor(ident->symbol);
+ auto intrinsic_type = sem::ParseIntrinsicType(name);
+ auto* call = (intrinsic_type != IntrinsicType::kNone)
+ ? IntrinsicCall(expr, intrinsic_type, std::move(args),
+ std::move(arg_tys))
+ : FunctionCall(expr, std::move(args));
+
+ current_function_->AddDirectCall(call);
+ return call;
+}
+
+sem::Call* Resolver::IntrinsicCall(
+ const ast::CallExpression* expr,
+ sem::IntrinsicType intrinsic_type,
+ const std::vector<const sem::Expression*> args,
+ const std::vector<const sem::Type*> arg_tys) {
auto* intrinsic = intrinsic_table_->Lookup(intrinsic_type, std::move(arg_tys),
expr->source);
if (!intrinsic) {
@@ -2509,21 +2550,45 @@
return nullptr;
}
- if (!ValidateCall(call)) {
+ if (!ValidateIntrinsicCall(call)) {
return nullptr;
}
return call;
}
-sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr) {
- auto* ident = expr->func;
- auto name = builder_->Symbols().NameFor(ident->symbol);
+bool Resolver::ValidateIntrinsicCall(const sem::Call* call) {
+ if (call->Type()->Is<sem::Void>()) {
+ bool is_call_statement = false;
+ if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
+ if (call_stmt->expr == call->Declaration()) {
+ is_call_statement = true;
+ }
+ }
+ if (!is_call_statement) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
+ // If the called function does not return a value, a function call
+ // statement should be used instead.
+ auto* ident = call->Declaration()->target.name;
+ auto name = builder_->Symbols().NameFor(ident->symbol);
+ AddError("intrinsic '" + name + "' does not return a value",
+ call->Declaration()->source);
+ return false;
+ }
+ }
- auto target_it = symbol_to_function_.find(ident->symbol);
+ return true;
+}
+
+sem::Call* Resolver::FunctionCall(
+ const ast::CallExpression* expr,
+ const std::vector<const sem::Expression*> args) {
+ auto sym = expr->target.name->symbol;
+ auto name = builder_->Symbols().NameFor(sym);
+
+ auto target_it = symbol_to_function_.find(sym);
if (target_it == symbol_to_function_.end()) {
- if (current_function_ &&
- current_function_->Declaration()->symbol == ident->symbol) {
+ if (current_function_ && current_function_->Declaration()->symbol == sym) {
AddError("recursion is not permitted. '" + name +
"' attempted to call itself.",
expr->source);
@@ -2533,16 +2598,6 @@
return nullptr;
}
auto* target = target_it->second;
-
- std::vector<const sem::Expression*> args(expr->args.size());
- for (size_t i = 0; i < expr->args.size(); i++) {
- auto* arg = Sem(expr->args[i]);
- if (!arg) {
- return nullptr;
- }
- args[i] = arg;
- }
-
auto* call = builder_->create<sem::Call>(expr, target, std::move(args),
current_statement_, sem::Constant{});
@@ -2567,38 +2622,9 @@
return nullptr;
}
- if (!ValidateCall(call)) {
- return nullptr;
- }
-
return call;
}
-bool Resolver::ValidateCall(const sem::Call* call) {
- if (call->Type()->Is<sem::Void>()) {
- bool is_call_statement = false;
- if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
- if (call_stmt->expr == call->Declaration()) {
- is_call_statement = true;
- }
- }
- if (!is_call_statement) {
- // https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
- // If the called function does not return a value, a function call
- // statement should be used instead.
- auto* ident = call->Declaration()->func;
- auto name = builder_->Symbols().NameFor(ident->symbol);
- bool is_function = call->Target()->Is<sem::Function>();
- AddError((is_function ? "function" : "intrinsic") + std::string(" '") +
- name + "' does not return a value",
- call->Declaration()->source);
- return false;
- }
- }
-
- return true;
-}
-
bool Resolver::ValidateTextureIntrinsicFunction(const sem::Call* call) {
auto* intrinsic = call->Target()->As<sem::Intrinsic>();
if (!intrinsic) {
@@ -2623,8 +2649,7 @@
bool is_const_expr = true;
ast::TraverseExpressions(
arg->Declaration(), diagnostics_, [&](const ast::Expression* e) {
- if (e->IsAnyOf<ast::LiteralExpression,
- ast::TypeConstructorExpression>()) {
+ if (e->IsAnyOf<ast::LiteralExpression, ast::CallExpression>()) {
return ast::TraverseAction::Descend;
}
is_const_expr = false;
@@ -2654,9 +2679,9 @@
bool Resolver::ValidateFunctionCall(const sem::Call* call) {
auto* decl = call->Declaration();
- auto* ident = decl->func;
auto* target = call->Target()->As<sem::Function>();
- auto name = builder_->Symbols().NameFor(ident->symbol);
+ auto sym = decl->target.name->symbol;
+ auto name = builder_->Symbols().NameFor(sym);
if (target->Declaration()->IsEntryPoint()) {
// https://www.w3.org/TR/WGSL/#function-restriction
@@ -2735,40 +2760,150 @@
}
}
}
+
+ if (call->Type()->Is<sem::Void>()) {
+ bool is_call_statement = false;
+ if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
+ if (call_stmt->expr == call->Declaration()) {
+ is_call_statement = true;
+ }
+ }
+ if (!is_call_statement) {
+ // https://gpuweb.github.io/gpuweb/wgsl/#function-call-expr
+ // If the called function does not return a value, a function call
+ // statement should be used instead.
+ AddError("function '" + name + "' does not return a value", decl->source);
+ return false;
+ }
+ }
return true;
}
-sem::Expression* Resolver::TypeConstructor(
- const ast::TypeConstructorExpression* expr) {
- auto* ty = Type(expr->type);
- if (!ty) {
+sem::Call* Resolver::TypeConversion(const ast::CallExpression* expr,
+ const sem::Type* target,
+ const sem::Expression* arg,
+ const sem::Type* source) {
+ // It is not valid to have a type-cast call expression inside a call
+ // statement.
+ if (current_statement_) {
+ if (auto* stmt =
+ current_statement_->Declaration()->As<ast::CallStatement>()) {
+ if (stmt->expr == expr) {
+ AddError("type cast evaluated but not used", expr->source);
+ return nullptr;
+ }
+ }
+ }
+
+ auto* call_target = utils::GetOrCreate(
+ type_conversions_, TypeConversionSig{target, source},
+ [&]() -> sem::TypeConversion* {
+ // Now that the argument types have been determined, make sure that they
+ // obey the conversion rules laid out in
+ // https://gpuweb.github.io/gpuweb/wgsl/#conversion-expr.
+ bool ok = true;
+ if (auto* vec_type = target->As<sem::Vector>()) {
+ ok = ValidateVectorConstructorOrCast(expr, vec_type);
+ } else if (auto* mat_type = target->As<sem::Matrix>()) {
+ // Note: Matrix types currently cannot be converted (the element type
+ // must only be f32). We implement this for the day we support other
+ // matrix element types.
+ ok = ValidateMatrixConstructorOrCast(expr, mat_type);
+ } else if (target->is_scalar()) {
+ ok = ValidateScalarConstructorOrCast(expr, target);
+ } else if (auto* arr_type = target->As<sem::Array>()) {
+ ok = ValidateArrayConstructorOrCast(expr, arr_type);
+ } else if (auto* struct_type = target->As<sem::Struct>()) {
+ ok = ValidateStructureConstructorOrCast(expr, struct_type);
+ } else {
+ AddError("type is not constructible", expr->source);
+ return nullptr;
+ }
+ if (!ok) {
+ return nullptr;
+ }
+
+ auto* param = builder_->create<sem::Parameter>(
+ nullptr, // declaration
+ 0, // index
+ source->UnwrapRef(), // type
+ ast::StorageClass::kNone, // storage_class
+ ast::Access::kUndefined); // access
+ return builder_->create<sem::TypeConversion>(target, param);
+ });
+
+ if (!call_target) {
return nullptr;
}
- // Now that the argument types have been determined, make sure that they
- // obey the constructor type rules laid out in
- // https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.
- bool ok = true;
- if (auto* vec_type = ty->As<sem::Vector>()) {
- ok = ValidateVectorConstructor(expr, vec_type);
- } else if (auto* mat_type = ty->As<sem::Matrix>()) {
- ok = ValidateMatrixConstructor(expr, mat_type);
- } else if (ty->is_scalar()) {
- ok = ValidateScalarConstructor(expr, ty);
- } else if (auto* arr_type = ty->As<sem::Array>()) {
- ok = ValidateArrayConstructor(expr, arr_type);
- } else if (auto* struct_type = ty->As<sem::Struct>()) {
- ok = ValidateStructureConstructor(expr, struct_type);
- } else {
- AddError("type is not constructible", expr->source);
- return nullptr;
+ auto val = EvaluateConstantValue(expr, target);
+ return builder_->create<sem::Call>(expr, call_target,
+ std::vector<const sem::Expression*>{arg},
+ current_statement_, val);
+}
+
+sem::Call* Resolver::TypeConstructor(
+ const ast::CallExpression* expr,
+ const sem::Type* ty,
+ const std::vector<const sem::Expression*> args,
+ const std::vector<const sem::Type*> arg_tys) {
+ // It is not valid to have a type-constructor call expression as a call
+ // statement.
+ if (current_statement_) {
+ if (auto* stmt =
+ current_statement_->Declaration()->As<ast::CallStatement>()) {
+ if (stmt->expr == expr) {
+ AddError("type constructor evaluated but not used", expr->source);
+ return nullptr;
+ }
+ }
}
- if (!ok) {
+
+ auto* call_target = utils::GetOrCreate(
+ type_ctors_, TypeConstructorSig{ty, arg_tys},
+ [&]() -> sem::TypeConstructor* {
+ // Now that the argument types have been determined, make sure that they
+ // obey the constructor type rules laid out in
+ // https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr.
+ bool ok = true;
+ if (auto* vec_type = ty->As<sem::Vector>()) {
+ ok = ValidateVectorConstructorOrCast(expr, vec_type);
+ } else if (auto* mat_type = ty->As<sem::Matrix>()) {
+ ok = ValidateMatrixConstructorOrCast(expr, mat_type);
+ } else if (ty->is_scalar()) {
+ ok = ValidateScalarConstructorOrCast(expr, ty);
+ } else if (auto* arr_type = ty->As<sem::Array>()) {
+ ok = ValidateArrayConstructorOrCast(expr, arr_type);
+ } else if (auto* struct_type = ty->As<sem::Struct>()) {
+ ok = ValidateStructureConstructorOrCast(expr, struct_type);
+ } else {
+ AddError("type is not constructible", expr->source);
+ return nullptr;
+ }
+ if (!ok) {
+ return nullptr;
+ }
+
+ return builder_->create<sem::TypeConstructor>(
+ ty, utils::Transform(
+ arg_tys,
+ [&](const sem::Type* t, size_t i) -> const sem::Parameter* {
+ return builder_->create<sem::Parameter>(
+ nullptr, // declaration
+ i, // index
+ t->UnwrapRef(), // type
+ ast::StorageClass::kNone, // storage_class
+ ast::Access::kUndefined); // access
+ }));
+ });
+
+ if (!call_target) {
return nullptr;
}
auto val = EvaluateConstantValue(expr, ty);
- return builder_->create<sem::Expression>(expr, ty, current_statement_, val);
+ return builder_->create<sem::Call>(expr, call_target, std::move(args),
+ current_statement_, val);
}
sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
@@ -2782,26 +2917,26 @@
val);
}
-bool Resolver::ValidateStructureConstructor(
- const ast::TypeConstructorExpression* ctor,
+bool Resolver::ValidateStructureConstructorOrCast(
+ const ast::CallExpression* ctor,
const sem::Struct* struct_type) {
if (!struct_type->IsConstructible()) {
AddError("struct constructor has non-constructible type", ctor->source);
return false;
}
- if (ctor->values.size() > 0) {
- if (ctor->values.size() != struct_type->Members().size()) {
+ if (ctor->args.size() > 0) {
+ if (ctor->args.size() != struct_type->Members().size()) {
std::string fm =
- ctor->values.size() < struct_type->Members().size() ? "few" : "many";
+ ctor->args.size() < struct_type->Members().size() ? "few" : "many";
AddError("struct constructor has too " + fm + " inputs: expected " +
std::to_string(struct_type->Members().size()) + ", found " +
- std::to_string(ctor->values.size()),
+ std::to_string(ctor->args.size()),
ctor->source);
return false;
}
for (auto* member : struct_type->Members()) {
- auto* value = ctor->values[member->Index()];
+ auto* value = ctor->args[member->Index()];
auto* value_ty = TypeOf(value);
if (member->Type() != value_ty->UnwrapRef()) {
AddError(
@@ -2817,10 +2952,9 @@
return true;
}
-bool Resolver::ValidateArrayConstructor(
- const ast::TypeConstructorExpression* ctor,
- const sem::Array* array_type) {
- auto& values = ctor->values;
+bool Resolver::ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Array* array_type) {
+ auto& values = ctor->args;
auto* elem_ty = array_type->ElemType();
for (auto* value : values) {
auto* value_ty = TypeOf(value)->UnwrapRef();
@@ -2839,7 +2973,7 @@
return false;
} else if (!elem_ty->IsConstructible()) {
AddError("array constructor has non-constructible element type",
- ctor->type->As<ast::Array>()->type->source);
+ ctor->source);
return false;
} else if (!values.empty() && (values.size() != array_type->Count())) {
std::string fm = values.size() < array_type->Count() ? "few" : "many";
@@ -2858,10 +2992,9 @@
return true;
}
-bool Resolver::ValidateVectorConstructor(
- const ast::TypeConstructorExpression* ctor,
- const sem::Vector* vec_type) {
- auto& values = ctor->values;
+bool Resolver::ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Vector* vec_type) {
+ auto& values = ctor->args;
auto* elem_ty = vec_type->type();
size_t value_cardinality_sum = 0;
for (auto* value : values) {
@@ -2937,10 +3070,9 @@
return true;
}
-bool Resolver::ValidateMatrixConstructor(
- const ast::TypeConstructorExpression* ctor,
- const sem::Matrix* matrix_ty) {
- auto& values = ctor->values;
+bool Resolver::ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Matrix* matrix_ty) {
+ auto& values = ctor->args;
// Zero Value expression
if (values.empty()) {
return true;
@@ -3000,21 +3132,20 @@
return true;
}
-bool Resolver::ValidateScalarConstructor(
- const ast::TypeConstructorExpression* ctor,
- const sem::Type* ty) {
- if (ctor->values.size() == 0) {
+bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Type* ty) {
+ if (ctor->args.size() == 0) {
return true;
}
- if (ctor->values.size() > 1) {
+ if (ctor->args.size() > 1) {
AddError("expected zero or one value in constructor, got " +
- std::to_string(ctor->values.size()),
+ std::to_string(ctor->args.size()),
ctor->source);
return false;
}
// Validate constructor
- auto* value = ctor->values[0];
+ auto* value = ctor->args[0];
auto* value_ty = TypeOf(value)->UnwrapRef();
using Bool = sem::Bool;
@@ -4547,5 +4678,37 @@
return sem;
}
+////////////////////////////////////////////////////////////////////////////////
+// Resolver::TypeConversionSig
+////////////////////////////////////////////////////////////////////////////////
+bool Resolver::TypeConversionSig::operator==(
+ const TypeConversionSig& rhs) const {
+ return target == rhs.target && source == rhs.source;
+}
+std::size_t Resolver::TypeConversionSig::Hasher::operator()(
+ const TypeConversionSig& sig) const {
+ return utils::Hash(sig.target, sig.source);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Resolver::TypeConstructorSig
+////////////////////////////////////////////////////////////////////////////////
+Resolver::TypeConstructorSig::TypeConstructorSig(
+ const sem::Type* ty,
+ const std::vector<const sem::Type*> params)
+ : type(ty), parameters(params) {}
+Resolver::TypeConstructorSig::TypeConstructorSig(const TypeConstructorSig&) =
+ default;
+Resolver::TypeConstructorSig::~TypeConstructorSig() = default;
+
+bool Resolver::TypeConstructorSig::operator==(
+ const TypeConstructorSig& rhs) const {
+ return type == rhs.type && parameters == rhs.parameters;
+}
+std::size_t Resolver::TypeConstructorSig::Hasher::operator()(
+ const TypeConstructorSig& sig) const {
+ return utils::Hash(sig.type, sig.parameters);
+}
+
} // namespace resolver
} // namespace tint
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 6578ce1..a82dab7 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -59,6 +59,7 @@
class Atomic;
class Intrinsic;
class Statement;
+class TypeConstructor;
} // namespace sem
namespace resolver {
@@ -170,15 +171,26 @@
sem::Expression* IndexAccessor(const ast::IndexAccessorExpression*);
sem::Expression* Binary(const ast::BinaryExpression*);
sem::Expression* Bitcast(const ast::BitcastExpression*);
- sem::Expression* Call(const ast::CallExpression*);
+ sem::Call* Call(const ast::CallExpression*);
sem::Expression* Expression(const ast::Expression*);
sem::Function* Function(const ast::Function*);
- sem::Call* FunctionCall(const ast::CallExpression*);
+ sem::Call* FunctionCall(const ast::CallExpression*,
+ const std::vector<const sem::Expression*> args);
sem::Expression* Identifier(const ast::IdentifierExpression*);
- sem::Call* IntrinsicCall(const ast::CallExpression*, sem::IntrinsicType);
+ sem::Call* IntrinsicCall(const ast::CallExpression*,
+ sem::IntrinsicType,
+ const std::vector<const sem::Expression*> args,
+ const std::vector<const sem::Type*> arg_tys);
sem::Expression* Literal(const ast::LiteralExpression*);
sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*);
- sem::Expression* TypeConstructor(const ast::TypeConstructorExpression*);
+ sem::Call* TypeConversion(const ast::CallExpression* expr,
+ const sem::Type* ty,
+ const sem::Expression* arg,
+ const sem::Type* arg_ty);
+ sem::Call* TypeConstructor(const ast::CallExpression* expr,
+ const sem::Type* ty,
+ const std::vector<const sem::Expression*> args,
+ const std::vector<const sem::Type*> arg_tys);
sem::Expression* UnaryOp(const ast::UnaryOpExpression*);
// Statement resolving methods
@@ -211,13 +223,13 @@
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type,
const bool is_input);
- bool ValidateCall(const sem::Call* call);
bool ValidateEntryPoint(const sem::Function* func);
bool ValidateFunction(const sem::Function* func);
bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type);
+ bool ValidateIntrinsicCall(const sem::Call* call);
bool ValidateLocationDecoration(const ast::LocationDecoration* location,
const sem::Type* type,
std::unordered_set<uint32_t>& locations,
@@ -234,23 +246,23 @@
bool ValidateStatements(const ast::StatementList& stmts);
bool ValidateStorageTexture(const ast::StorageTexture* t);
bool ValidateStructure(const sem::Struct* str);
- bool ValidateStructureConstructor(const ast::TypeConstructorExpression* ctor,
- const sem::Struct* struct_type);
+ bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Struct* struct_type);
bool ValidateSwitch(const ast::SwitchStatement* s);
bool ValidateVariable(const sem::Variable* var);
- bool ValidateVariableConstructor(const ast::Variable* var,
- ast::StorageClass storage_class,
- const sem::Type* storage_type,
- const sem::Type* rhs_type);
+ bool ValidateVariableConstructorOrCast(const ast::Variable* var,
+ ast::StorageClass storage_class,
+ const sem::Type* storage_type,
+ const sem::Type* rhs_type);
bool ValidateVector(const sem::Vector* ty, const Source& source);
- bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
- const sem::Vector* vec_type);
- bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
- const sem::Matrix* matrix_type);
- bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor,
- const sem::Type* type);
- bool ValidateArrayConstructor(const ast::TypeConstructorExpression* ctor,
- const sem::Array* arr_type);
+ bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Vector* vec_type);
+ bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Matrix* matrix_type);
+ bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Type* type);
+ bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
+ const sem::Array* arr_type);
bool ValidateTypeDecl(const ast::TypeDecl* named_type) const;
bool ValidateTextureIntrinsicFunction(const sem::Call* call);
bool ValidateNoDuplicateDecorations(const ast::DecorationList& decorations);
@@ -378,15 +390,46 @@
const sem::Type* type);
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type);
- sem::Constant EvaluateConstantValue(
- const ast::TypeConstructorExpression* type_ctor,
- const sem::Type* type);
+ sem::Constant EvaluateConstantValue(const ast::CallExpression* call,
+ const sem::Type* type);
/// Sem is a helper for obtaining the semantic node for the given AST node.
template <typename SEM = sem::Info::InferFromAST,
typename AST_OR_TYPE = CastableBase>
const sem::Info::GetResultType<SEM, AST_OR_TYPE>* Sem(const AST_OR_TYPE* ast);
+ struct TypeConversionSig {
+ const sem::Type* target;
+ const sem::Type* source;
+
+ bool operator==(const TypeConversionSig&) const;
+
+ /// Hasher provides a hash function for the TypeConversionSig
+ struct Hasher {
+ /// @param sig the TypeConversionSig to create a hash for
+ /// @return the hash value
+ std::size_t operator()(const TypeConversionSig& sig) const;
+ };
+ };
+
+ struct TypeConstructorSig {
+ const sem::Type* type;
+ const std::vector<const sem::Type*> parameters;
+
+ TypeConstructorSig(const sem::Type* ty,
+ const std::vector<const sem::Type*> params);
+ TypeConstructorSig(const TypeConstructorSig&);
+ ~TypeConstructorSig();
+ bool operator==(const TypeConstructorSig&) const;
+
+ /// Hasher provides a hash function for the TypeConstructorSig
+ struct Hasher {
+ /// @param sig the TypeConstructorSig to create a hash for
+ /// @return the hash value
+ std::size_t operator()(const TypeConstructorSig& sig) const;
+ };
+ };
+
ProgramBuilder* const builder_;
diag::List& diagnostics_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_;
@@ -398,6 +441,14 @@
std::unordered_set<const ast::Node*> marked_;
std::unordered_map<uint32_t, const sem::Variable*> constant_ids_;
+ std::unordered_map<TypeConversionSig,
+ sem::CallTarget*,
+ TypeConversionSig::Hasher>
+ type_conversions_;
+ std::unordered_map<TypeConstructorSig,
+ sem::CallTarget*,
+ TypeConstructorSig::Hasher>
+ type_ctors_;
sem::Function* current_function_ = nullptr;
sem::Statement* current_statement_ = nullptr;
diff --git a/src/resolver/resolver_constants.cc b/src/resolver/resolver_constants.cc
index 4252e1b..d2ba745 100644
--- a/src/resolver/resolver_constants.cc
+++ b/src/resolver/resolver_constants.cc
@@ -15,6 +15,7 @@
#include "src/resolver/resolver.h"
#include "src/sem/constant.h"
+#include "src/sem/type_constructor.h"
#include "src/utils/get_or_create.h"
namespace tint {
@@ -32,7 +33,7 @@
if (auto* e = expr->As<ast::LiteralExpression>()) {
return EvaluateConstantValue(e, type);
}
- if (auto* e = expr->As<ast::TypeConstructorExpression>()) {
+ if (auto* e = expr->As<ast::CallExpression>()) {
return EvaluateConstantValue(e, type);
}
return {};
@@ -57,10 +58,8 @@
return {};
}
-sem::Constant Resolver::EvaluateConstantValue(
- const ast::TypeConstructorExpression* type_ctor,
- const sem::Type* type) {
- auto& ctor_values = type_ctor->values;
+sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
+ const sem::Type* type) {
auto* vec = type->As<sem::Vector>();
// For now, only fold scalars and vectors
@@ -72,7 +71,7 @@
int result_size = vec ? static_cast<int>(vec->Width()) : 1;
// For zero value init, return 0s
- if (ctor_values.empty()) {
+ if (call->args.empty()) {
if (elem_type->Is<sem::I32>()) {
return sem::Constant(type, sem::Constant::Scalars(result_size, 0));
}
@@ -90,12 +89,12 @@
// Build value for type_ctor from each child value by casting to
// type_ctor's type.
sem::Constant::Scalars elems;
- for (auto* cv : ctor_values) {
- auto* expr = builder_->Sem().Get(cv);
- if (!expr || !expr->ConstantValue()) {
+ for (auto* expr : call->args) {
+ auto* arg = builder_->Sem().Get(expr);
+ if (!arg || !arg->ConstantValue()) {
return {};
}
- auto cast = ConstantCast(expr->ConstantValue(), elem_type);
+ auto cast = ConstantCast(arg->ConstantValue(), elem_type);
elems.insert(elems.end(), cast.Elements().begin(), cast.Elements().end());
}
diff --git a/src/resolver/type_constructor_validation_test.cc b/src/resolver/type_constructor_validation_test.cc
index aa81cf6..2c83801 100644
--- a/src/resolver/type_constructor_validation_test.cc
+++ b/src/resolver/type_constructor_validation_test.cc
@@ -15,6 +15,8 @@
#include "gmock/gmock.h"
#include "src/resolver/resolver_test_helper.h"
#include "src/sem/reference_type.h"
+#include "src/sem/type_constructor.h"
+#include "src/sem/type_conversion.h"
namespace tint {
namespace resolver {
@@ -223,68 +225,74 @@
} // namespace InferTypeTest
-namespace ConversionConstructorTest {
+namespace ConversionConstructTest {
+enum class Kind {
+ Construct,
+ Conversion,
+};
+
struct Params {
+ Kind kind;
builder::ast_type_func_ptr lhs_type;
builder::ast_type_func_ptr rhs_type;
builder::ast_expr_func_ptr rhs_value_expr;
};
template <typename LhsType, typename RhsType>
-constexpr Params ParamsFor() {
- return Params{DataType<LhsType>::AST, DataType<RhsType>::AST,
+constexpr Params ParamsFor(Kind kind) {
+ return Params{kind, DataType<LhsType>::AST, DataType<RhsType>::AST,
DataType<RhsType>::Expr};
}
static constexpr Params valid_cases[] = {
// Direct init (non-conversions)
- ParamsFor<bool, bool>(), //
- ParamsFor<i32, i32>(), //
- ParamsFor<u32, u32>(), //
- ParamsFor<f32, f32>(), //
- ParamsFor<vec3<bool>, vec3<bool>>(), //
- ParamsFor<vec3<i32>, vec3<i32>>(), //
- ParamsFor<vec3<u32>, vec3<u32>>(), //
- ParamsFor<vec3<f32>, vec3<f32>>(), //
+ ParamsFor<bool, bool>(Kind::Construct), //
+ ParamsFor<i32, i32>(Kind::Construct), //
+ ParamsFor<u32, u32>(Kind::Construct), //
+ ParamsFor<f32, f32>(Kind::Construct), //
+ ParamsFor<vec3<bool>, vec3<bool>>(Kind::Construct), //
+ ParamsFor<vec3<i32>, vec3<i32>>(Kind::Construct), //
+ ParamsFor<vec3<u32>, vec3<u32>>(Kind::Construct), //
+ ParamsFor<vec3<f32>, vec3<f32>>(Kind::Construct), //
// Splat
- ParamsFor<vec3<bool>, bool>(), //
- ParamsFor<vec3<i32>, i32>(), //
- ParamsFor<vec3<u32>, u32>(), //
- ParamsFor<vec3<f32>, f32>(), //
+ ParamsFor<vec3<bool>, bool>(Kind::Construct), //
+ ParamsFor<vec3<i32>, i32>(Kind::Construct), //
+ ParamsFor<vec3<u32>, u32>(Kind::Construct), //
+ ParamsFor<vec3<f32>, f32>(Kind::Construct), //
// Conversion
- ParamsFor<bool, u32>(), //
- ParamsFor<bool, i32>(), //
- ParamsFor<bool, f32>(), //
+ ParamsFor<bool, u32>(Kind::Conversion), //
+ ParamsFor<bool, i32>(Kind::Conversion), //
+ ParamsFor<bool, f32>(Kind::Conversion), //
- ParamsFor<i32, bool>(), //
- ParamsFor<i32, u32>(), //
- ParamsFor<i32, f32>(), //
+ ParamsFor<i32, bool>(Kind::Conversion), //
+ ParamsFor<i32, u32>(Kind::Conversion), //
+ ParamsFor<i32, f32>(Kind::Conversion), //
- ParamsFor<u32, bool>(), //
- ParamsFor<u32, i32>(), //
- ParamsFor<u32, f32>(), //
+ ParamsFor<u32, bool>(Kind::Conversion), //
+ ParamsFor<u32, i32>(Kind::Conversion), //
+ ParamsFor<u32, f32>(Kind::Conversion), //
- ParamsFor<f32, bool>(), //
- ParamsFor<f32, u32>(), //
- ParamsFor<f32, i32>(), //
+ ParamsFor<f32, bool>(Kind::Conversion), //
+ ParamsFor<f32, u32>(Kind::Conversion), //
+ ParamsFor<f32, i32>(Kind::Conversion), //
- ParamsFor<vec3<bool>, vec3<u32>>(), //
- ParamsFor<vec3<bool>, vec3<i32>>(), //
- ParamsFor<vec3<bool>, vec3<f32>>(), //
+ ParamsFor<vec3<bool>, vec3<u32>>(Kind::Conversion), //
+ ParamsFor<vec3<bool>, vec3<i32>>(Kind::Conversion), //
+ ParamsFor<vec3<bool>, vec3<f32>>(Kind::Conversion), //
- ParamsFor<vec3<i32>, vec3<bool>>(), //
- ParamsFor<vec3<i32>, vec3<u32>>(), //
- ParamsFor<vec3<i32>, vec3<f32>>(), //
+ ParamsFor<vec3<i32>, vec3<bool>>(Kind::Conversion), //
+ ParamsFor<vec3<i32>, vec3<u32>>(Kind::Conversion), //
+ ParamsFor<vec3<i32>, vec3<f32>>(Kind::Conversion), //
- ParamsFor<vec3<u32>, vec3<bool>>(), //
- ParamsFor<vec3<u32>, vec3<i32>>(), //
- ParamsFor<vec3<u32>, vec3<f32>>(), //
+ ParamsFor<vec3<u32>, vec3<bool>>(Kind::Conversion), //
+ ParamsFor<vec3<u32>, vec3<i32>>(Kind::Conversion), //
+ ParamsFor<vec3<u32>, vec3<f32>>(Kind::Conversion), //
- ParamsFor<vec3<f32>, vec3<bool>>(), //
- ParamsFor<vec3<f32>, vec3<u32>>(), //
- ParamsFor<vec3<f32>, vec3<i32>>(), //
+ ParamsFor<vec3<f32>, vec3<bool>>(Kind::Conversion), //
+ ParamsFor<vec3<f32>, vec3<u32>>(Kind::Conversion), //
+ ParamsFor<vec3<f32>, vec3<i32>>(Kind::Conversion), //
};
using ConversionConstructorValidTest = ResolverTestWithParam<Params>;
@@ -302,8 +310,9 @@
<< FriendlyName(rhs_type) << "(<rhs value expr>))";
SCOPED_TRACE(ss.str());
- auto* a = Var("a", lhs_type1, ast::StorageClass::kNone,
- Construct(lhs_type2, Construct(rhs_type, rhs_value_expr)));
+ auto* arg = Construct(rhs_type, rhs_value_expr);
+ auto* tc = Construct(lhs_type2, arg);
+ auto* a = Var("a", lhs_type1, ast::StorageClass::kNone, tc);
// Self-assign 'a' to force the expression to be resolved so we can test its
// type below
@@ -311,6 +320,27 @@
WrapInFunction(Decl(a), Assign(a_ident, "a"));
ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ switch (params.kind) {
+ case Kind::Construct: {
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_EQ(ctor->Parameters()[0]->Type(), TypeOf(arg));
+ break;
+ }
+ case Kind::Conversion: {
+ auto* conv = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(conv, nullptr);
+ EXPECT_EQ(call->Type(), conv->ReturnType());
+ ASSERT_EQ(conv->Parameters().size(), 1u);
+ EXPECT_EQ(conv->Parameters()[0]->Type(), TypeOf(arg));
+ break;
+ }
+ }
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
ConversionConstructorValidTest,
@@ -408,7 +438,7 @@
"'array<f32, 4>'");
}
-} // namespace ConversionConstructorTest
+} // namespace ConversionConstructTest
namespace ArrayConstructor {
@@ -418,7 +448,15 @@
auto* tc = array<u32, 10>();
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve());
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ EXPECT_TRUE(call->Type()->Is<sem::Array>());
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 0u);
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -427,7 +465,18 @@
auto* tc = array<u32, 3>(Expr(0u), Expr(10u), Expr(20u));
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve());
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ EXPECT_TRUE(call->Type()->Is<sem::Array>());
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::U32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::U32>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -587,6 +636,118 @@
} // namespace ArrayConstructor
+namespace ScalarConstructor {
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Construct_i32_Success) {
+ auto* expr = Construct<i32>(Expr(123));
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::I32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Construct_u32_Success) {
+ auto* expr = Construct<u32>(Expr(123u));
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::U32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Construct_f32_Success) {
+ auto* expr = Construct<f32>(Expr(1.23f));
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::F32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Convert_f32_to_i32_Success) {
+ auto* expr = Construct<i32>(1.23f);
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::I32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Convert_i32_to_u32_Success) {
+ auto* expr = Construct<u32>(123);
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::U32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, Expr_Convert_u32_to_f32_Success) {
+ auto* expr = Construct<f32>(123u);
+ WrapInFunction(expr);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+
+ ASSERT_NE(TypeOf(expr), nullptr);
+ ASSERT_TRUE(TypeOf(expr)->Is<sem::F32>());
+
+ auto* call = Sem().Get(expr);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
+}
+
+} // namespace ScalarConstructor
+
namespace VectorConstructor {
TEST_F(ResolverTypeConstructorValidationTest,
@@ -708,12 +869,19 @@
auto* tc = vec2<f32>();
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 0u);
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -721,12 +889,21 @@
auto* tc = vec2<f32>(1.0f, 1.0f);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::F32>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -734,12 +911,21 @@
auto* tc = vec2<u32>(1u, 1u);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::U32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::U32>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -747,12 +933,21 @@
auto* tc = vec2<i32>(1, 1);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -760,12 +955,21 @@
auto* tc = vec2<bool>(true, false);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::Bool>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Bool>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Bool>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -773,12 +977,20 @@
auto* tc = vec2<f32>(vec2<f32>());
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -786,12 +998,20 @@
auto* tc = vec2<f32>(vec2<i32>());
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 2u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -938,12 +1158,19 @@
auto* tc = vec3<f32>();
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 0u);
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -951,12 +1178,22 @@
auto* tc = vec3<f32>(1.0f, 1.0f, 1.0f);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::F32>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -964,12 +1201,22 @@
auto* tc = vec3<u32>(1u, 1u, 1u);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::U32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::U32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::U32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::U32>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -977,12 +1224,22 @@
auto* tc = vec3<i32>(1, 1, 1);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -990,12 +1247,22 @@
auto* tc = vec3<bool>(true, false, true);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::Bool>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Bool>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Bool>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::Bool>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -1003,12 +1270,21 @@
auto* tc = vec3<f32>(vec2<f32>(), 1.0f);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::F32>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -1016,12 +1292,21 @@
auto* tc = vec3<f32>(1.0f, vec2<f32>());
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::F32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Vector>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -1029,12 +1314,20 @@
auto* tc = vec3<f32>(vec3<f32>());
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -1042,12 +1335,20 @@
auto* tc = vec3<f32>(vec3<i32>());
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
EXPECT_TRUE(TypeOf(tc)->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(TypeOf(tc)->As<sem::Vector>()->Width(), 3u);
+
+ auto* call = Sem().Get(tc);
+ ASSERT_NE(call, nullptr);
+ auto* ctor = call->Target()->As<sem::TypeConversion>();
+ ASSERT_NE(ctor, nullptr);
+ EXPECT_EQ(call->Type(), ctor->ReturnType());
+ ASSERT_EQ(ctor->Parameters().size(), 1u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -1248,7 +1549,7 @@
auto* tc = vec4<f32>();
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1261,7 +1562,7 @@
auto* tc = vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1274,7 +1575,7 @@
auto* tc = vec4<u32>(1u, 1u, 1u, 1u);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1287,7 +1588,7 @@
auto* tc = vec4<i32>(1, 1, 1, 1);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1300,7 +1601,7 @@
auto* tc = vec4<bool>(true, false, true, false);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1313,7 +1614,7 @@
auto* tc = vec4<f32>(vec2<f32>(), 1.0f, 1.0f);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1326,7 +1627,7 @@
auto* tc = vec4<f32>(1.0f, vec2<f32>(), 1.0f);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1339,7 +1640,7 @@
auto* tc = vec4<f32>(1.0f, 1.0f, vec2<f32>());
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1352,7 +1653,7 @@
auto* tc = vec4<f32>(vec2<f32>(), vec2<f32>());
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1365,7 +1666,7 @@
auto* tc = vec4<f32>(vec3<f32>(), 1.0f);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1378,7 +1679,7 @@
auto* tc = vec4<f32>(1.0f, vec3<f32>());
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1391,7 +1692,7 @@
auto* tc = vec4<f32>(vec4<f32>());
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1404,7 +1705,7 @@
auto* tc = vec4<f32>(vec4<i32>());
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1431,7 +1732,7 @@
auto* tc = vec4<f32>(vec3<f32>(vec2<f32>(1.0f, 1.0f), 1.0f), 1.0f);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(tc), nullptr);
ASSERT_TRUE(TypeOf(tc)->Is<sem::Vector>());
@@ -1462,7 +1763,7 @@
auto* tc = vec3<f32>("my_vec2", "my_f32");
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -1490,7 +1791,7 @@
auto* tc = Construct(Source{{12, 34}}, vec_type, 1.0f, 1.0f);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -1517,7 +1818,7 @@
auto* tc = vec3<f32>(Construct(Source{{12, 34}}, vec_type), 1.0f);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
} // namespace VectorConstructor
@@ -1728,7 +2029,7 @@
auto* tc = Construct(Source{{12, 40}}, matrix_type);
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(MatrixConstructorTest, Expr_Constructor_WithColumns_Success) {
@@ -1746,7 +2047,7 @@
auto* tc = Construct(Source{}, matrix_type, std::move(args));
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(MatrixConstructorTest, Expr_Constructor_WithElements_Success) {
@@ -1763,7 +2064,7 @@
auto* tc = Construct(Source{}, matrix_type, std::move(args));
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
@@ -1804,7 +2105,7 @@
auto* tc = Construct(Source{}, matrix_type, std::move(args));
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTypeConstructorValidationTest,
@@ -1839,7 +2140,7 @@
auto* tc = Construct(Source{}, matrix_type, std::move(args));
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Error) {
@@ -1877,7 +2178,7 @@
auto* tc = Construct(Source{}, matrix_type, std::move(args));
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest,
@@ -2044,7 +2345,7 @@
auto* s = Structure("MyInputs", {m});
auto* tc = Construct(Source{{12, 34}}, ty.Of(s));
WrapInFunction(tc);
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTypeConstructorValidationTest, Expr_Constructor_Struct_Empty) {
@@ -2055,7 +2356,7 @@
});
WrapInFunction(Construct(ty.Of(str)));
- EXPECT_TRUE(r()->Resolve()) << r()->error();
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
}
} // namespace StructConstructor
@@ -2070,7 +2371,7 @@
TEST_F(ResolverTypeConstructorValidationTest,
NonConstructibleType_AtomicArray) {
WrapInFunction(Call(
- "ignore", Construct(ty.array(ty.atomic(Source{{12, 34}}, ty.i32()), 4))));
+ "ignore", Construct(Source{{12, 34}}, ty.array(ty.atomic(ty.i32()), 4))));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
@@ -2097,6 +2398,22 @@
EXPECT_EQ(r()->error(), "12:34 error: type is not constructible");
}
+TEST_F(ResolverTypeConstructorValidationTest, TypeConstructorAsStatement) {
+ WrapInFunction(
+ CallStmt(Construct(Source{{12, 34}}, ty.vec2<f32>(), 1.f, 2.f)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(),
+ "12:34 error: type constructor evaluated but not used");
+}
+
+TEST_F(ResolverTypeConstructorValidationTest, TypeConversionAsStatement) {
+ WrapInFunction(CallStmt(Construct(Source{{12, 34}}, ty.f32(), 1)));
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), "12:34 error: type cast evaluated but not used");
+}
+
} // namespace
} // namespace resolver
} // namespace tint
diff --git a/src/sem/type_cast.cc b/src/sem/type_conversion.cc
similarity index 73%
rename from src/sem/type_cast.cc
rename to src/sem/type_conversion.cc
index aa39c36..4f9de30 100644
--- a/src/sem/type_cast.cc
+++ b/src/sem/type_conversion.cc
@@ -12,17 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/sem/type_cast.h"
+#include "src/sem/type_conversion.h"
-TINT_INSTANTIATE_TYPEINFO(tint::sem::TypeCast);
+TINT_INSTANTIATE_TYPEINFO(tint::sem::TypeConversion);
namespace tint {
namespace sem {
-TypeCast::TypeCast(const sem::Type* type, const sem::Parameter* parameter)
+TypeConversion::TypeConversion(const sem::Type* type,
+ const sem::Parameter* parameter)
: Base(type, ParameterList{parameter}) {}
-TypeCast::~TypeCast() = default;
+TypeConversion::~TypeConversion() = default;
} // namespace sem
} // namespace tint
diff --git a/src/sem/type_cast.h b/src/sem/type_conversion.h
similarity index 74%
rename from src/sem/type_cast.h
rename to src/sem/type_conversion.h
index 4acd888..c39202a 100644
--- a/src/sem/type_cast.h
+++ b/src/sem/type_conversion.h
@@ -12,24 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef SRC_SEM_TYPE_CAST_H_
-#define SRC_SEM_TYPE_CAST_H_
+#ifndef SRC_SEM_TYPE_CONVERSION_H_
+#define SRC_SEM_TYPE_CONVERSION_H_
#include "src/sem/call_target.h"
namespace tint {
namespace sem {
-/// TypeCast is the CallTarget for a type cast.
-class TypeCast : public Castable<TypeCast, CallTarget> {
+/// TypeConversion is the CallTarget for a type conversion (cast).
+class TypeConversion : public Castable<TypeConversion, CallTarget> {
public:
/// Constructor
/// @param type the target type of the cast
/// @param parameter the type cast parameter
- TypeCast(const sem::Type* type, const sem::Parameter* parameter);
+ TypeConversion(const sem::Type* type, const sem::Parameter* parameter);
/// Destructor
- ~TypeCast() override;
+ ~TypeConversion() override;
/// @returns the cast source type
const sem::Type* Source() const { return Parameters()[0]->Type(); }
@@ -41,4 +41,4 @@
} // namespace sem
} // namespace tint
-#endif // SRC_SEM_TYPE_CAST_H_
+#endif // SRC_SEM_TYPE_CONVERSION_H_
diff --git a/src/transform/external_texture_transform.cc b/src/transform/external_texture_transform.cc
index 3b540ad..4dcf1fe 100644
--- a/src/transform/external_texture_transform.cc
+++ b/src/transform/external_texture_transform.cc
@@ -78,7 +78,7 @@
// Replace the call with another that has the same parameters in
// addition to a level parameter (always zero for external
// textures).
- auto* exp = ctx.Clone(call_expr->func);
+ auto* exp = ctx.Clone(call_expr->target.name);
auto* externalTextureParam = ctx.Clone(call_expr->args[0]);
ast::ExpressionList params;
diff --git a/src/transform/fold_constants.cc b/src/transform/fold_constants.cc
index 994bc1e..215c5bf 100644
--- a/src/transform/fold_constants.cc
+++ b/src/transform/fold_constants.cc
@@ -19,7 +19,10 @@
#include <vector>
#include "src/program_builder.h"
+#include "src/sem/call.h"
#include "src/sem/expression.h"
+#include "src/sem/type_constructor.h"
+#include "src/sem/type_conversion.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldConstants);
@@ -32,26 +35,25 @@
void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
- auto* sem = ctx.src->Sem().Get(expr);
- if (!sem) {
+ auto* call = ctx.src->Sem().Get<sem::Call>(expr);
+ if (!call) {
return nullptr;
}
- auto value = sem->ConstantValue();
+ auto value = call->ConstantValue();
if (!value.IsValid()) {
return nullptr;
}
- auto* ty = sem->Type();
+ auto* ty = call->Type();
- auto* ctor = expr->As<ast::TypeConstructorExpression>();
- if (!ctor) {
+ if (!call->Target()->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
return nullptr;
}
// If original ctor expression had no init values, don't replace the
// expression
- if (ctor->values.size() == 0) {
+ if (call->Arguments().empty()) {
return nullptr;
}
@@ -68,7 +70,7 @@
// create it with 3. So what we do is construct with vec_size args,
// except if the original vector was single-value initialized, in
// which case, we only construct with one arg again.
- uint32_t ctor_size = (ctor->values.size() == 1) ? 1 : vec_size;
+ uint32_t ctor_size = (call->Arguments().size() == 1) ? 1 : vec_size;
ast::ExpressionList ctors;
for (uint32_t i = 0; i < ctor_size; ++i) {
diff --git a/src/transform/module_scope_var_to_entry_point_param.cc b/src/transform/module_scope_var_to_entry_point_param.cc
index 0efb0b4..ba61160 100644
--- a/src/transform/module_scope_var_to_entry_point_param.cc
+++ b/src/transform/module_scope_var_to_entry_point_param.cc
@@ -307,7 +307,8 @@
// Pass the variables as pointers to any functions that need them.
for (auto* call : calls_to_replace[func_ast]) {
- auto* target = ctx.src->AST().Functions().Find(call->func->symbol);
+ auto* target =
+ ctx.src->AST().Functions().Find(call->target.name->symbol);
auto* target_sem = ctx.src->Sem().Get(target);
// Add new arguments for any variables that are needed by the callee.
diff --git a/src/transform/pad_array_elements.cc b/src/transform/pad_array_elements.cc
index ca23096..a14ac7b 100644
--- a/src/transform/pad_array_elements.cc
+++ b/src/transform/pad_array_elements.cc
@@ -19,7 +19,9 @@
#include "src/program_builder.h"
#include "src/sem/array.h"
+#include "src/sem/call.h"
#include "src/sem/expression.h"
+#include "src/sem/type_constructor.h"
#include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::PadArrayElements);
@@ -131,26 +133,29 @@
// Fix up array constructors so `A(1,2)` becomes
// `A(padded(1), padded(2))`
- ctx.ReplaceAll([&](const ast::TypeConstructorExpression* ctor)
- -> const ast::Expression* {
- if (auto* array =
- tint::As<sem::Array>(sem.Get(ctor)->Type()->UnwrapRef())) {
- if (auto p = pad(array)) {
- auto* arr_ty = p();
- auto el_typename = arr_ty->type->As<ast::TypeName>()->name;
+ ctx.ReplaceAll(
+ [&](const ast::CallExpression* expr) -> const ast::Expression* {
+ auto* call = sem.Get(expr);
+ if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
+ if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
+ if (auto p = pad(array)) {
+ auto* arr_ty = p();
+ auto el_typename = arr_ty->type->As<ast::TypeName>()->name;
- ast::ExpressionList args;
- args.reserve(ctor->values.size());
- for (auto* arg : ctor->values) {
- args.emplace_back(ctx.dst->Construct(
- ctx.dst->create<ast::TypeName>(el_typename), ctx.Clone(arg)));
+ ast::ExpressionList args;
+ args.reserve(call->Arguments().size());
+ for (auto* arg : call->Arguments()) {
+ auto* val = ctx.Clone(arg->Declaration());
+ args.emplace_back(ctx.dst->Construct(
+ ctx.dst->create<ast::TypeName>(el_typename), val));
+ }
+
+ return ctx.dst->Construct(arr_ty, args);
+ }
+ }
}
-
- return ctx.dst->Construct(arr_ty, args);
- }
- }
- return nullptr;
- });
+ return nullptr;
+ });
ctx.Clone();
}
diff --git a/src/transform/promote_initializers_to_const_var.cc b/src/transform/promote_initializers_to_const_var.cc
index 4128d6f..9a9001a 100644
--- a/src/transform/promote_initializers_to_const_var.cc
+++ b/src/transform/promote_initializers_to_const_var.cc
@@ -18,8 +18,10 @@
#include "src/program_builder.h"
#include "src/sem/block_statement.h"
+#include "src/sem/call.h"
#include "src/sem/expression.h"
#include "src/sem/statement.h"
+#include "src/sem/type_constructor.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteInitializersToConstVar);
@@ -50,14 +52,12 @@
// pointer can be passed to the parent's constructor.
for (auto* src_node : ctx.src->ASTNodes().Objects()) {
- if (auto* src_init = src_node->As<ast::TypeConstructorExpression>()) {
- auto* src_sem_expr = ctx.src->Sem().Get(src_init);
- if (!src_sem_expr) {
- TINT_ICE(Transform, ctx.dst->Diagnostics())
- << "ast::TypeConstructorExpression has no semantic expression node";
+ if (auto* src_init = src_node->As<ast::CallExpression>()) {
+ auto* call = ctx.src->Sem().Get(src_init);
+ if (!call->Target()->Is<sem::TypeConstructor>()) {
continue;
}
- auto* src_sem_stmt = src_sem_expr->Stmt();
+ auto* src_sem_stmt = call->Stmt();
if (!src_sem_stmt) {
// Expression is outside of a statement. This usually means the
// expression is part of a global (module-scope) constant declaration.
@@ -76,12 +76,12 @@
}
}
- auto* src_ty = src_sem_expr->Type();
+ auto* src_ty = call->Type();
if (src_ty->IsAnyOf<sem::Array, sem::Struct>()) {
// Create a new symbol for the constant
auto dst_symbol = ctx.dst->Sym();
// Clone the type
- auto* dst_ty = ctx.Clone(src_init->type);
+ auto* dst_ty = CreateASTTypeFor(ctx, call->Type());
// Clone the initializer
auto* dst_init = ctx.Clone(src_init);
// Construct the constant that holds the hoisted initializer
diff --git a/src/transform/promote_initializers_to_const_var_test.cc b/src/transform/promote_initializers_to_const_var_test.cc
index 23fe3e3..2af5404 100644
--- a/src/transform/promote_initializers_to_const_var_test.cc
+++ b/src/transform/promote_initializers_to_const_var_test.cc
@@ -30,7 +30,7 @@
var f1 : f32 = 2.0;
var f2 : f32 = 3.0;
var f3 : f32 = 4.0;
- var i : f32 = array<f32, 4>(f0, f1, f2, f3)[2];
+ var i : f32 = array<f32, 4u>(f0, f1, f2, f3)[2];
}
)";
@@ -41,7 +41,7 @@
var f1 : f32 = 2.0;
var f2 : f32 = 3.0;
var f3 : f32 = 4.0;
- let tint_symbol : array<f32, 4> = array<f32, 4>(f0, f1, f2, f3);
+ let tint_symbol : array<f32, 4u> = array<f32, 4u>(f0, f1, f2, f3);
var i : f32 = tint_symbol[2];
}
)";
@@ -88,16 +88,16 @@
auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
- var i : f32 = array<array<f32, 2>, 2>(array<f32, 2>(1.0, 2.0), array<f32, 2>(3.0, 4.0))[0][1];
+ var i : f32 = array<array<f32, 2u>, 2u>(array<f32, 2u>(1.0, 2.0), array<f32, 2u>(3.0, 4.0))[0][1];
}
)";
auto* expect = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
- let tint_symbol : array<f32, 2> = array<f32, 2>(1.0, 2.0);
- let tint_symbol_1 : array<f32, 2> = array<f32, 2>(3.0, 4.0);
- let tint_symbol_2 : array<array<f32, 2>, 2> = array<array<f32, 2>, 2>(tint_symbol, tint_symbol_1);
+ let tint_symbol : array<f32, 2u> = array<f32, 2u>(1.0, 2.0);
+ let tint_symbol_1 : array<f32, 2u> = array<f32, 2u>(3.0, 4.0);
+ let tint_symbol_2 : array<array<f32, 2u>, 2u> = array<array<f32, 2u>, 2u>(tint_symbol, tint_symbol_1);
var i : f32 = tint_symbol_2[0][1];
}
)";
@@ -165,12 +165,12 @@
};
struct S2 {
- a : array<S1, 3>;
+ a : array<S1, 3u>;
};
[[stage(compute), workgroup_size(1)]]
fn main() {
- var x : i32 = S2(array<S1, 3>(S1(1), S1(2), S1(3))).a[1].a;
+ var x : i32 = S2(array<S1, 3u>(S1(1), S1(2), S1(3))).a[1].a;
}
)";
@@ -180,7 +180,7 @@
};
struct S2 {
- a : array<S1, 3>;
+ a : array<S1, 3u>;
};
[[stage(compute), workgroup_size(1)]]
@@ -188,7 +188,7 @@
let tint_symbol : S1 = S1(1);
let tint_symbol_1 : S1 = S1(2);
let tint_symbol_2 : S1 = S1(3);
- let tint_symbol_3 : array<S1, 3> = array<S1, 3>(tint_symbol, tint_symbol_1, tint_symbol_2);
+ let tint_symbol_3 : array<S1, 3u> = array<S1, 3u>(tint_symbol, tint_symbol_1, tint_symbol_2);
let tint_symbol_4 : S2 = S2(tint_symbol_3);
var x : i32 = tint_symbol_4.a[1].a;
}
@@ -209,11 +209,11 @@
[[stage(compute), workgroup_size(1)]]
fn main() {
- var local_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0);
+ var local_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
var local_str : S = S(1, 2.0, 3);
}
-let module_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0);
+let module_arr : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
let module_str : S = S(1, 2.0, 3);
)";
diff --git a/src/transform/renamer.cc b/src/transform/renamer.cc
index 1fa6f1f..01a115e 100644
--- a/src/transform/renamer.cc
+++ b/src/transform/renamer.cc
@@ -1285,7 +1285,7 @@
continue;
}
if (sem->Target()->Is<sem::Intrinsic>()) {
- preserve.emplace(call->func);
+ preserve.emplace(call->target.name);
}
}
}
diff --git a/src/transform/vectorize_scalar_matrix_constructors.cc b/src/transform/vectorize_scalar_matrix_constructors.cc
index 3c7d97e..dc4ae12 100644
--- a/src/transform/vectorize_scalar_matrix_constructors.cc
+++ b/src/transform/vectorize_scalar_matrix_constructors.cc
@@ -17,7 +17,9 @@
#include <utility>
#include "src/program_builder.h"
+#include "src/sem/call.h"
#include "src/sem/expression.h"
+#include "src/sem/type_constructor.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixConstructors);
@@ -33,38 +35,44 @@
void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx,
const DataMap&,
DataMap&) {
- ctx.ReplaceAll([&](const ast::TypeConstructorExpression* constructor)
- -> const ast::TypeConstructorExpression* {
- // Check if this is a matrix constructor with scalar arguments.
- auto* mat_type = ctx.src->Sem().Get(constructor->type)->As<sem::Matrix>();
- if (!mat_type) {
- return nullptr;
- }
- if (constructor->values.size() == 0) {
- return nullptr;
- }
- if (!ctx.src->Sem().Get(constructor->values[0])->Type()->is_scalar()) {
- return nullptr;
- }
+ ctx.ReplaceAll(
+ [&](const ast::CallExpression* expr) -> const ast::CallExpression* {
+ auto* call = ctx.src->Sem().Get(expr);
+ auto* ty_ctor = call->Target()->As<sem::TypeConstructor>();
+ if (!ty_ctor) {
+ return nullptr;
+ }
+ // Check if this is a matrix constructor with scalar arguments.
+ auto* mat_type = call->Type()->As<sem::Matrix>();
+ if (!mat_type) {
+ return nullptr;
+ }
- // Build a list of vector expressions for each column.
- ast::ExpressionList columns;
- for (uint32_t c = 0; c < mat_type->columns(); c++) {
- // Build a list of scalar expressions for each value in the column.
- ast::ExpressionList row_values;
- for (uint32_t r = 0; r < mat_type->rows(); r++) {
- row_values.push_back(
- ctx.Clone(constructor->values[c * mat_type->rows() + r]));
- }
+ auto& args = call->Arguments();
+ if (args.size() == 0) {
+ return nullptr;
+ }
+ if (!args[0]->Type()->is_scalar()) {
+ return nullptr;
+ }
- // Construct the column vector.
- auto* col = ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()),
- mat_type->rows(), row_values);
- columns.push_back(col);
- }
+ // Build a list of vector expressions for each column.
+ ast::ExpressionList columns;
+ for (uint32_t c = 0; c < mat_type->columns(); c++) {
+ // Build a list of scalar expressions for each value in the column.
+ ast::ExpressionList row_values;
+ for (uint32_t r = 0; r < mat_type->rows(); r++) {
+ row_values.push_back(
+ ctx.Clone(args[c * mat_type->rows() + r]->Declaration()));
+ }
- return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
- });
+ // Construct the column vector.
+ auto* col = ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()),
+ mat_type->rows(), row_values);
+ columns.push_back(col);
+ }
+ return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
+ });
ctx.Clone();
}
diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc
index bafe5ca..3403fcc 100644
--- a/src/transform/vertex_pulling.cc
+++ b/src/transform/vertex_pulling.cc
@@ -715,8 +715,8 @@
LoadPrimitive(array_base, primitive_offset, buffer, base_format));
}
- return ctx.dst->Construct(
- ctx.dst->create<ast::Vector>(base_type, count), std::move(expr_list));
+ return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count),
+ std::move(expr_list));
}
/// Process a non-struct entry point parameter.
diff --git a/src/transform/wrap_arrays_in_structs.cc b/src/transform/wrap_arrays_in_structs.cc
index f6c1268..f034e33 100644
--- a/src/transform/wrap_arrays_in_structs.cc
+++ b/src/transform/wrap_arrays_in_structs.cc
@@ -18,8 +18,11 @@
#include "src/program_builder.h"
#include "src/sem/array.h"
+#include "src/sem/call.h"
#include "src/sem/expression.h"
+#include "src/sem/type_constructor.h"
#include "src/utils/get_or_create.h"
+#include "src/utils/transform.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::WrapArraysInStructs);
@@ -74,21 +77,28 @@
});
// Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))`
- ctx.ReplaceAll([&](const ast::TypeConstructorExpression* ctor)
- -> const ast::Expression* {
- if (auto* array =
- ::tint::As<sem::Array>(sem.Get(ctor)->Type()->UnwrapRef())) {
- if (auto w = wrapper(array)) {
- // Wrap the array type constructor with another constructor for
- // the wrapper
- auto* wrapped_array_ty = ctx.Clone(ctor->type);
- auto* array_ty = w.array_type(ctx);
- auto* arr_ctor = ctx.dst->Construct(array_ty, ctx.Clone(ctor->values));
- return ctx.dst->Construct(wrapped_array_ty, arr_ctor);
- }
- }
- return nullptr;
- });
+ ctx.ReplaceAll(
+ [&](const ast::CallExpression* expr) -> const ast::Expression* {
+ if (auto* call = sem.Get(expr)) {
+ if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
+ if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
+ if (auto w = wrapper(array)) {
+ // Wrap the array type constructor with another constructor for
+ // the wrapper
+ auto* wrapped_array_ty = ctx.dst->ty.type_name(w.wrapper_name);
+ auto* array_ty = w.array_type(ctx);
+ auto args = utils::Transform(
+ call->Arguments(), [&](const tint::sem::Expression* s) {
+ return ctx.Clone(s->Declaration());
+ });
+ auto* arr_ctor = ctx.dst->Construct(array_ty, args);
+ return ctx.dst->Construct(wrapped_array_ty, arr_ctor);
+ }
+ }
+ }
+ }
+ return nullptr;
+ });
ctx.Clone();
}
diff --git a/src/writer/append_vector.cc b/src/writer/append_vector.cc
index 5059690..a7a8728 100644
--- a/src/writer/append_vector.cc
+++ b/src/writer/append_vector.cc
@@ -15,34 +15,66 @@
#include "src/writer/append_vector.h"
#include <utility>
+#include <vector>
+#include "src/sem/call.h"
#include "src/sem/expression.h"
+#include "src/sem/type_constructor.h"
+#include "src/sem/type_conversion.h"
+#include "src/utils/transform.h"
namespace tint {
namespace writer {
namespace {
-const ast::TypeConstructorExpression* AsVectorConstructor(
- ProgramBuilder* b,
- const ast::Expression* expr) {
- if (auto* constructor = expr->As<ast::TypeConstructorExpression>()) {
- if (b->TypeOf(constructor)->Is<sem::Vector>()) {
- return constructor;
+struct VectorConstructorInfo {
+ const sem::Call* call = nullptr;
+ const sem::TypeConstructor* ctor = nullptr;
+ operator bool() const { return call != nullptr; }
+};
+VectorConstructorInfo AsVectorConstructor(const sem::Expression* expr) {
+ if (auto* call = expr->As<sem::Call>()) {
+ if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
+ if (ctor->ReturnType()->Is<sem::Vector>()) {
+ return {call, ctor};
+ }
}
}
- return nullptr;
+ return {};
+}
+
+const sem::Expression* Zero(ProgramBuilder& b,
+ const sem::Type* ty,
+ const sem::Statement* stmt) {
+ const ast::Expression* expr = nullptr;
+ if (ty->Is<sem::I32>()) {
+ expr = b.Expr(0);
+ } else if (ty->Is<sem::U32>()) {
+ expr = b.Expr(0u);
+ } else if (ty->Is<sem::F32>()) {
+ expr = b.Expr(0.0f);
+ } else if (ty->Is<sem::Bool>()) {
+ expr = b.Expr(false);
+ } else {
+ TINT_UNREACHABLE(Writer, b.Diagnostics())
+ << "unsupported vector element type: " << ty->TypeInfo().name;
+ return nullptr;
+ }
+ auto* sem = b.create<sem::Expression>(expr, ty, stmt, sem::Constant{});
+ b.Sem().Add(expr, sem);
+ return sem;
}
} // namespace
-const ast::TypeConstructorExpression* AppendVector(
- ProgramBuilder* b,
- const ast::Expression* vector,
- const ast::Expression* scalar) {
+const sem::Call* AppendVector(ProgramBuilder* b,
+ const ast::Expression* vector_ast,
+ const ast::Expression* scalar_ast) {
uint32_t packed_size;
const sem::Type* packed_el_sem_ty;
- auto* vector_sem = b->Sem().Get(vector);
+ auto* vector_sem = b->Sem().Get(vector_ast);
+ auto* scalar_sem = b->Sem().Get(scalar_ast);
auto* vector_ty = vector_sem->Type()->UnwrapRef();
if (auto* vec = vector_ty->As<sem::Vector>()) {
packed_size = vec->Width() + 1;
@@ -52,15 +84,15 @@
packed_el_sem_ty = vector_ty;
}
- const ast::Type* packed_el_ty = nullptr;
+ const ast::Type* packed_el_ast_ty = nullptr;
if (packed_el_sem_ty->Is<sem::I32>()) {
- packed_el_ty = b->create<ast::I32>();
+ packed_el_ast_ty = b->create<ast::I32>();
} else if (packed_el_sem_ty->Is<sem::U32>()) {
- packed_el_ty = b->create<ast::U32>();
+ packed_el_ast_ty = b->create<ast::U32>();
} else if (packed_el_sem_ty->Is<sem::F32>()) {
- packed_el_ty = b->create<ast::F32>();
+ packed_el_ast_ty = b->create<ast::F32>();
} else if (packed_el_sem_ty->Is<sem::Bool>()) {
- packed_el_ty = b->create<ast::Bool>();
+ packed_el_ast_ty = b->create<ast::Bool>();
} else {
TINT_UNREACHABLE(Writer, b->Diagnostics())
<< "unsupported vector element type: "
@@ -69,7 +101,7 @@
auto* statement = vector_sem->Stmt();
- auto* packed_ty = b->create<ast::Vector>(packed_el_ty, packed_size);
+ auto* packed_ast_ty = b->create<ast::Vector>(packed_el_ast_ty, packed_size);
auto* packed_sem_ty = b->create<sem::Vector>(packed_el_sem_ty, packed_size);
// If the coordinates are already passed in a vector constructor, with only
@@ -80,61 +112,61 @@
// The other cases for a nested vector constructor are when it is used
// to convert a vector of a different type, e.g. vec2<i32>(vec2<u32>()).
// In that case, preserve the original argument, or you'll get a type error.
- ast::ExpressionList packed;
- if (auto* vc = AsVectorConstructor(b, vector)) {
- const auto num_supplied = vc->values.size();
+
+ std::vector<const sem::Expression*> packed;
+ if (auto vc = AsVectorConstructor(vector_sem)) {
+ const auto num_supplied = vc.call->Arguments().size();
if (num_supplied == 0) {
// Zero-value vector constructor. Populate with zeros
- auto buildZero = [&]() -> const ast::LiteralExpression* {
- if (packed_el_sem_ty->Is<sem::I32>()) {
- return b->Expr(0);
- } else if (packed_el_sem_ty->Is<sem::U32>()) {
- return b->Expr(0u);
- } else if (packed_el_sem_ty->Is<sem::F32>()) {
- return b->Expr(0.0f);
- } else if (packed_el_sem_ty->Is<sem::Bool>()) {
- return b->Expr(false);
- } else {
- TINT_UNREACHABLE(Writer, b->Diagnostics())
- << "unsupported vector element type: "
- << packed_el_sem_ty->TypeInfo().name;
- }
- return nullptr;
- };
-
for (uint32_t i = 0; i < packed_size - 1; i++) {
- auto* zero = buildZero();
- b->Sem().Add(
- zero, b->create<sem::Expression>(zero, packed_el_sem_ty, statement,
- sem::Constant{}));
+ auto* zero = Zero(*b, packed_el_sem_ty, statement);
packed.emplace_back(zero);
}
} else if (num_supplied + 1 == packed_size) {
// All vector components were supplied as scalars. Pass them through.
- packed = vc->values;
+ packed = vc.call->Arguments();
}
}
if (packed.empty()) {
// The special cases didn't occur. Use the vector argument as-is.
- packed.emplace_back(vector);
+ packed.emplace_back(vector_sem);
}
- if (packed_el_sem_ty != b->TypeOf(scalar)->UnwrapRef()) {
+
+ if (packed_el_sem_ty != scalar_sem->Type()->UnwrapRef()) {
// Cast scalar to the vector element type
- auto* scalar_cast = b->Construct(packed_el_ty, scalar);
- b->Sem().Add(scalar_cast,
- b->create<sem::Expression>(scalar_cast, packed_el_sem_ty,
- statement, sem::Constant{}));
- packed.emplace_back(scalar_cast);
+ auto* scalar_cast_ast = b->Construct(packed_el_ast_ty, scalar_ast);
+ auto* scalar_cast_target = b->create<sem::TypeConversion>(
+ packed_el_sem_ty,
+ b->create<sem::Parameter>(nullptr, 0, scalar_sem->Type()->UnwrapRef(),
+ ast::StorageClass::kNone,
+ ast::Access::kUndefined));
+ auto* scalar_cast_sem =
+ b->create<sem::Call>(scalar_cast_ast, scalar_cast_target,
+ std::vector<const sem::Expression*>{scalar_sem},
+ statement, sem::Constant{});
+ b->Sem().Add(scalar_cast_ast, scalar_cast_sem);
+ packed.emplace_back(scalar_cast_sem);
} else {
- packed.emplace_back(scalar);
+ packed.emplace_back(scalar_sem);
}
- auto* constructor = b->Construct(packed_ty, std::move(packed));
- b->Sem().Add(constructor,
- b->create<sem::Expression>(constructor, packed_sem_ty, statement,
- sem::Constant{}));
-
- return constructor;
+ auto* constructor_ast = b->Construct(
+ packed_ast_ty, utils::Transform(packed, [&](const sem::Expression* expr) {
+ return expr->Declaration();
+ }));
+ auto* constructor_target = b->create<sem::TypeConstructor>(
+ packed_sem_ty,
+ utils::Transform(packed,
+ [&](const tint::sem::Expression* arg,
+ size_t i) -> const sem::Parameter* {
+ return b->create<sem::Parameter>(
+ nullptr, i, arg->Type()->UnwrapRef(),
+ ast::StorageClass::kNone, ast::Access::kUndefined);
+ }));
+ auto* constructor_sem = b->create<sem::Call>(
+ constructor_ast, constructor_target, packed, statement, sem::Constant{});
+ b->Sem().Add(constructor_ast, constructor_sem);
+ return constructor_sem;
}
} // namespace writer
diff --git a/src/writer/append_vector.h b/src/writer/append_vector.h
index e95d4b4..5e28271 100644
--- a/src/writer/append_vector.h
+++ b/src/writer/append_vector.h
@@ -20,8 +20,8 @@
namespace tint {
namespace ast {
+class CallExpression;
class Expression;
-class TypeConstructorExpression;
} // namespace ast
namespace writer {
@@ -36,10 +36,9 @@
/// @param scalar the scalar to append to the vector. Must be a scalar.
/// @returns a vector expression containing the elements of `vector` followed by
/// the single element of `scalar` cast to the `vector` element type.
-const ast::TypeConstructorExpression* AppendVector(
- ProgramBuilder* builder,
- const ast::Expression* vector,
- const ast::Expression* scalar);
+const sem::Call* AppendVector(ProgramBuilder* builder,
+ const ast::Expression* vector,
+ const ast::Expression* scalar);
} // namespace writer
} // namespace tint
diff --git a/src/writer/append_vector_test.cc b/src/writer/append_vector_test.cc
index 67030e8..e348d8d 100644
--- a/src/writer/append_vector_test.cc
+++ b/src/writer/append_vector_test.cc
@@ -15,6 +15,7 @@
#include "src/writer/append_vector.h"
#include "src/program_builder.h"
#include "src/resolver/resolver.h"
+#include "src/sem/type_constructor.h"
#include "gtest/gtest.h"
@@ -24,6 +25,7 @@
class AppendVectorTest : public ::testing::Test, public ProgramBuilder {};
+// AppendVector(vec2<i32>(1, 2), 3) -> vec3<i32>(1, 2, 3)
TEST_F(AppendVectorTest, Vec2i32_i32) {
auto* scalar_1 = Expr(1);
auto* scalar_2 = Expr(2);
@@ -34,15 +36,36 @@
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_123 = AppendVector(this, vec_12, scalar_3)
- ->As<ast::TypeConstructorExpression>();
+ auto* append = AppendVector(this, vec_12, scalar_3);
+
+ auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr);
- ASSERT_EQ(vec_123->values.size(), 3u);
- EXPECT_EQ(vec_123->values[0], scalar_1);
- EXPECT_EQ(vec_123->values[1], scalar_2);
- EXPECT_EQ(vec_123->values[2], scalar_3);
+ ASSERT_EQ(vec_123->args.size(), 3u);
+ EXPECT_EQ(vec_123->args[0], scalar_1);
+ EXPECT_EQ(vec_123->args[1], scalar_2);
+ EXPECT_EQ(vec_123->args[2], scalar_3);
+
+ auto* call = Sem().Get(vec_123);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 3u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(scalar_1));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_2));
+ EXPECT_EQ(call->Arguments()[2], Sem().Get(scalar_3));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
}
+// AppendVector(vec2<i32>(1, 2), 3u) -> vec3<i32>(1, 2, i32(3u))
TEST_F(AppendVectorTest, Vec2i32_u32) {
auto* scalar_1 = Expr(1);
auto* scalar_2 = Expr(2);
@@ -53,19 +76,41 @@
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_123 = AppendVector(this, vec_12, scalar_3)
- ->As<ast::TypeConstructorExpression>();
+ auto* append = AppendVector(this, vec_12, scalar_3);
+
+ auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr);
- ASSERT_EQ(vec_123->values.size(), 3u);
- EXPECT_EQ(vec_123->values[0], scalar_1);
- EXPECT_EQ(vec_123->values[1], scalar_2);
- auto* u32_to_i32 = vec_123->values[2]->As<ast::TypeConstructorExpression>();
+ ASSERT_EQ(vec_123->args.size(), 3u);
+ EXPECT_EQ(vec_123->args[0], scalar_1);
+ EXPECT_EQ(vec_123->args[1], scalar_2);
+ auto* u32_to_i32 = vec_123->args[2]->As<ast::CallExpression>();
ASSERT_NE(u32_to_i32, nullptr);
- EXPECT_TRUE(u32_to_i32->type->Is<ast::I32>());
- ASSERT_EQ(u32_to_i32->values.size(), 1u);
- EXPECT_EQ(u32_to_i32->values[0], scalar_3);
+ EXPECT_TRUE(u32_to_i32->target.type->Is<ast::I32>());
+ ASSERT_EQ(u32_to_i32->args.size(), 1u);
+ EXPECT_EQ(u32_to_i32->args[0], scalar_3);
+
+ auto* call = Sem().Get(vec_123);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 3u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(scalar_1));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_2));
+ EXPECT_EQ(call->Arguments()[2], Sem().Get(u32_to_i32));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
}
+// AppendVector(vec2<i32>(vec2<u32>(1u, 2u)), 3u) ->
+// vec3<i32>(vec2<i32>(vec2<u32>(1u, 2u)), i32(3u))
TEST_F(AppendVectorTest, Vec2i32FromVec2u32_u32) {
auto* scalar_1 = Expr(1u);
auto* scalar_2 = Expr(2u);
@@ -77,26 +122,45 @@
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_123 = AppendVector(this, vec_12, scalar_3)
- ->As<ast::TypeConstructorExpression>();
- ASSERT_NE(vec_123, nullptr);
- ASSERT_EQ(vec_123->values.size(), 2u);
- auto* v2u32_to_v2i32 =
- vec_123->values[0]->As<ast::TypeConstructorExpression>();
- ASSERT_NE(v2u32_to_v2i32, nullptr);
- ASSERT_TRUE(v2u32_to_v2i32->type->Is<ast::Vector>());
- EXPECT_EQ(v2u32_to_v2i32->type->As<ast::Vector>()->width, 2u);
- EXPECT_TRUE(v2u32_to_v2i32->type->As<ast::Vector>()->type->Is<ast::I32>());
- EXPECT_EQ(v2u32_to_v2i32->values.size(), 1u);
- EXPECT_EQ(v2u32_to_v2i32->values[0], uvec_12);
+ auto* append = AppendVector(this, vec_12, scalar_3);
- auto* u32_to_i32 = vec_123->values[1]->As<ast::TypeConstructorExpression>();
+ auto* vec_123 = As<ast::CallExpression>(append->Declaration());
+ ASSERT_NE(vec_123, nullptr);
+ ASSERT_EQ(vec_123->args.size(), 2u);
+ auto* v2u32_to_v2i32 = vec_123->args[0]->As<ast::CallExpression>();
+ ASSERT_NE(v2u32_to_v2i32, nullptr);
+ ASSERT_TRUE(v2u32_to_v2i32->target.type->Is<ast::Vector>());
+ EXPECT_EQ(v2u32_to_v2i32->target.type->As<ast::Vector>()->width, 2u);
+ EXPECT_TRUE(
+ v2u32_to_v2i32->target.type->As<ast::Vector>()->type->Is<ast::I32>());
+ EXPECT_EQ(v2u32_to_v2i32->args.size(), 1u);
+ EXPECT_EQ(v2u32_to_v2i32->args[0], uvec_12);
+
+ auto* u32_to_i32 = vec_123->args[1]->As<ast::CallExpression>();
ASSERT_NE(u32_to_i32, nullptr);
- EXPECT_TRUE(u32_to_i32->type->Is<ast::I32>());
- ASSERT_EQ(u32_to_i32->values.size(), 1u);
- EXPECT_EQ(u32_to_i32->values[0], scalar_3);
+ EXPECT_TRUE(u32_to_i32->target.type->Is<ast::I32>());
+ ASSERT_EQ(u32_to_i32->args.size(), 1u);
+ EXPECT_EQ(u32_to_i32->args[0], scalar_3);
+
+ auto* call = Sem().Get(vec_123);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 2u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_12));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(u32_to_i32));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
}
+// AppendVector(vec2<i32>(1, 2), 3.0f) -> vec3<i32>(1, 2, i32(3.0f))
TEST_F(AppendVectorTest, Vec2i32_f32) {
auto* scalar_1 = Expr(1);
auto* scalar_2 = Expr(2);
@@ -107,40 +171,84 @@
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_123 = AppendVector(this, vec_12, scalar_3)
- ->As<ast::TypeConstructorExpression>();
+ auto* append = AppendVector(this, vec_12, scalar_3);
+
+ auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr);
- ASSERT_EQ(vec_123->values.size(), 3u);
- EXPECT_EQ(vec_123->values[0], scalar_1);
- EXPECT_EQ(vec_123->values[1], scalar_2);
- auto* f32_to_i32 = vec_123->values[2]->As<ast::TypeConstructorExpression>();
+ ASSERT_EQ(vec_123->args.size(), 3u);
+ EXPECT_EQ(vec_123->args[0], scalar_1);
+ EXPECT_EQ(vec_123->args[1], scalar_2);
+ auto* f32_to_i32 = vec_123->args[2]->As<ast::CallExpression>();
ASSERT_NE(f32_to_i32, nullptr);
- EXPECT_TRUE(f32_to_i32->type->Is<ast::I32>());
- ASSERT_EQ(f32_to_i32->values.size(), 1u);
- EXPECT_EQ(f32_to_i32->values[0], scalar_3);
+ EXPECT_TRUE(f32_to_i32->target.type->Is<ast::I32>());
+ ASSERT_EQ(f32_to_i32->args.size(), 1u);
+ EXPECT_EQ(f32_to_i32->args[0], scalar_3);
+
+ auto* call = Sem().Get(vec_123);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 3u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(scalar_1));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_2));
+ EXPECT_EQ(call->Arguments()[2], Sem().Get(f32_to_i32));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
}
+// AppendVector(vec3<i32>(1, 2, 3), 4) -> vec4<i32>(1, 2, 3, 4)
TEST_F(AppendVectorTest, Vec3i32_i32) {
auto* scalar_1 = Expr(1);
auto* scalar_2 = Expr(2);
auto* scalar_3 = Expr(3);
- auto* scalar_4 = Expr(3);
+ auto* scalar_4 = Expr(4);
auto* vec_123 = vec3<i32>(scalar_1, scalar_2, scalar_3);
WrapInFunction(vec_123, scalar_4);
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_1234 = AppendVector(this, vec_123, scalar_4)
- ->As<ast::TypeConstructorExpression>();
+ auto* append = AppendVector(this, vec_123, scalar_4);
+
+ auto* vec_1234 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_1234, nullptr);
- ASSERT_EQ(vec_1234->values.size(), 4u);
- EXPECT_EQ(vec_1234->values[0], scalar_1);
- EXPECT_EQ(vec_1234->values[1], scalar_2);
- EXPECT_EQ(vec_1234->values[2], scalar_3);
- EXPECT_EQ(vec_1234->values[3], scalar_4);
+ ASSERT_EQ(vec_1234->args.size(), 4u);
+ EXPECT_EQ(vec_1234->args[0], scalar_1);
+ EXPECT_EQ(vec_1234->args[1], scalar_2);
+ EXPECT_EQ(vec_1234->args[2], scalar_3);
+ EXPECT_EQ(vec_1234->args[3], scalar_4);
+
+ auto* call = Sem().Get(vec_1234);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 4u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(scalar_1));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_2));
+ EXPECT_EQ(call->Arguments()[2], Sem().Get(scalar_3));
+ EXPECT_EQ(call->Arguments()[3], Sem().Get(scalar_4));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 4u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 4u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[3]->Type()->Is<sem::I32>());
}
+// AppendVector(vec_12, 3) -> vec3<i32>(vec_12, 3)
TEST_F(AppendVectorTest, Vec2i32Var_i32) {
Global("vec_12", ty.vec2<i32>(), ast::StorageClass::kPrivate);
auto* vec_12 = Expr("vec_12");
@@ -150,14 +258,33 @@
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_123 = AppendVector(this, vec_12, scalar_3)
- ->As<ast::TypeConstructorExpression>();
+ auto* append = AppendVector(this, vec_12, scalar_3);
+
+ auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr);
- ASSERT_EQ(vec_123->values.size(), 2u);
- EXPECT_EQ(vec_123->values[0], vec_12);
- EXPECT_EQ(vec_123->values[1], scalar_3);
+ ASSERT_EQ(vec_123->args.size(), 2u);
+ EXPECT_EQ(vec_123->args[0], vec_12);
+ EXPECT_EQ(vec_123->args[1], scalar_3);
+
+ auto* call = Sem().Get(vec_123);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 2u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_12));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_3));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
}
+// AppendVector(1, 2, scalar_3) -> vec3<i32>(1, 2, scalar_3)
TEST_F(AppendVectorTest, Vec2i32_i32Var) {
Global("scalar_3", ty.i32(), ast::StorageClass::kPrivate);
auto* scalar_1 = Expr(1);
@@ -169,15 +296,36 @@
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_123 = AppendVector(this, vec_12, scalar_3)
- ->As<ast::TypeConstructorExpression>();
+ auto* append = AppendVector(this, vec_12, scalar_3);
+
+ auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr);
- ASSERT_EQ(vec_123->values.size(), 3u);
- EXPECT_EQ(vec_123->values[0], scalar_1);
- EXPECT_EQ(vec_123->values[1], scalar_2);
- EXPECT_EQ(vec_123->values[2], scalar_3);
+ ASSERT_EQ(vec_123->args.size(), 3u);
+ EXPECT_EQ(vec_123->args[0], scalar_1);
+ EXPECT_EQ(vec_123->args[1], scalar_2);
+ EXPECT_EQ(vec_123->args[2], scalar_3);
+
+ auto* call = Sem().Get(vec_123);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 3u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(scalar_1));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_2));
+ EXPECT_EQ(call->Arguments()[2], Sem().Get(scalar_3));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 3u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
}
+// AppendVector(vec_12, scalar_3) -> vec3<i32>(vec_12, scalar_3)
TEST_F(AppendVectorTest, Vec2i32Var_i32Var) {
Global("vec_12", ty.vec2<i32>(), ast::StorageClass::kPrivate);
Global("scalar_3", ty.i32(), ast::StorageClass::kPrivate);
@@ -188,14 +336,33 @@
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_123 = AppendVector(this, vec_12, scalar_3)
- ->As<ast::TypeConstructorExpression>();
+ auto* append = AppendVector(this, vec_12, scalar_3);
+
+ auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr);
- ASSERT_EQ(vec_123->values.size(), 2u);
- EXPECT_EQ(vec_123->values[0], vec_12);
- EXPECT_EQ(vec_123->values[1], scalar_3);
+ ASSERT_EQ(vec_123->args.size(), 2u);
+ EXPECT_EQ(vec_123->args[0], vec_12);
+ EXPECT_EQ(vec_123->args[1], scalar_3);
+
+ auto* call = Sem().Get(vec_123);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 2u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_12));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_3));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
}
+// AppendVector(vec_12, scalar_3) -> vec3<i32>(vec_12, i32(scalar_3))
TEST_F(AppendVectorTest, Vec2i32Var_f32Var) {
Global("vec_12", ty.vec2<i32>(), ast::StorageClass::kPrivate);
Global("scalar_3", ty.f32(), ast::StorageClass::kPrivate);
@@ -206,18 +373,37 @@
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_123 = AppendVector(this, vec_12, scalar_3)
- ->As<ast::TypeConstructorExpression>();
+ auto* append = AppendVector(this, vec_12, scalar_3);
+
+ auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr);
- ASSERT_EQ(vec_123->values.size(), 2u);
- EXPECT_EQ(vec_123->values[0], vec_12);
- auto* f32_to_i32 = vec_123->values[1]->As<ast::TypeConstructorExpression>();
+ ASSERT_EQ(vec_123->args.size(), 2u);
+ EXPECT_EQ(vec_123->args[0], vec_12);
+ auto* f32_to_i32 = vec_123->args[1]->As<ast::CallExpression>();
ASSERT_NE(f32_to_i32, nullptr);
- EXPECT_TRUE(f32_to_i32->type->Is<ast::I32>());
- ASSERT_EQ(f32_to_i32->values.size(), 1u);
- EXPECT_EQ(f32_to_i32->values[0], scalar_3);
+ EXPECT_TRUE(f32_to_i32->target.type->Is<ast::I32>());
+ ASSERT_EQ(f32_to_i32->args.size(), 1u);
+ EXPECT_EQ(f32_to_i32->args[0], scalar_3);
+
+ auto* call = Sem().Get(vec_123);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 2u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_12));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(f32_to_i32));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
}
+// AppendVector(vec_12, scalar_3) -> vec3<bool>(vec_12, scalar_3)
TEST_F(AppendVectorTest, Vec2boolVar_boolVar) {
Global("vec_12", ty.vec2<bool>(), ast::StorageClass::kPrivate);
Global("scalar_3", ty.bool_(), ast::StorageClass::kPrivate);
@@ -228,14 +414,33 @@
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_123 = AppendVector(this, vec_12, scalar_3)
- ->As<ast::TypeConstructorExpression>();
+ auto* append = AppendVector(this, vec_12, scalar_3);
+
+ auto* vec_123 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_123, nullptr);
- ASSERT_EQ(vec_123->values.size(), 2u);
- EXPECT_EQ(vec_123->values[0], vec_12);
- EXPECT_EQ(vec_123->values[1], scalar_3);
+ ASSERT_EQ(vec_123->args.size(), 2u);
+ EXPECT_EQ(vec_123->args[0], vec_12);
+ EXPECT_EQ(vec_123->args[1], scalar_3);
+
+ auto* call = Sem().Get(vec_123);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 2u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_12));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(scalar_3));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 3u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::Bool>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 2u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::Vector>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::Bool>());
}
+// AppendVector(vec3<i32>(), 4) -> vec3<bool>(0, 0, 0, 4)
TEST_F(AppendVectorTest, ZeroVec3i32_i32) {
auto* scalar = Expr(4);
auto* vec000 = vec3<i32>();
@@ -244,16 +449,38 @@
resolver::Resolver resolver(this);
ASSERT_TRUE(resolver.Resolve()) << resolver.error();
- auto* vec_0004 =
- AppendVector(this, vec000, scalar)->As<ast::TypeConstructorExpression>();
+ auto* append = AppendVector(this, vec000, scalar);
+
+ auto* vec_0004 = As<ast::CallExpression>(append->Declaration());
ASSERT_NE(vec_0004, nullptr);
- ASSERT_EQ(vec_0004->values.size(), 4u);
+ ASSERT_EQ(vec_0004->args.size(), 4u);
for (size_t i = 0; i < 3; i++) {
- auto* literal = As<ast::SintLiteralExpression>(vec_0004->values[i]);
+ auto* literal = As<ast::SintLiteralExpression>(vec_0004->args[i]);
ASSERT_NE(literal, nullptr);
EXPECT_EQ(literal->value, 0);
}
- EXPECT_EQ(vec_0004->values[3], scalar);
+ EXPECT_EQ(vec_0004->args[3], scalar);
+
+ auto* call = Sem().Get(vec_0004);
+ ASSERT_NE(call, nullptr);
+ ASSERT_EQ(call->Arguments().size(), 4u);
+ EXPECT_EQ(call->Arguments()[0], Sem().Get(vec_0004->args[0]));
+ EXPECT_EQ(call->Arguments()[1], Sem().Get(vec_0004->args[1]));
+ EXPECT_EQ(call->Arguments()[2], Sem().Get(vec_0004->args[2]));
+ EXPECT_EQ(call->Arguments()[3], Sem().Get(scalar));
+
+ auto* ctor = call->Target()->As<sem::TypeConstructor>();
+ ASSERT_NE(ctor, nullptr);
+ ASSERT_TRUE(ctor->ReturnType()->Is<sem::Vector>());
+ EXPECT_EQ(ctor->ReturnType()->As<sem::Vector>()->Width(), 4u);
+ EXPECT_TRUE(ctor->ReturnType()->As<sem::Vector>()->type()->Is<sem::I32>());
+ EXPECT_EQ(ctor->ReturnType(), call->Type());
+
+ ASSERT_EQ(ctor->Parameters().size(), 4u);
+ EXPECT_TRUE(ctor->Parameters()[0]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[1]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[2]->Type()->Is<sem::I32>());
+ EXPECT_TRUE(ctor->Parameters()[3]->Type()->Is<sem::I32>());
}
} // namespace
diff --git a/src/writer/glsl/generator_impl.cc b/src/writer/glsl/generator_impl.cc
index 9c0f775..b7ea9e2 100644
--- a/src/writer/glsl/generator_impl.cc
+++ b/src/writer/glsl/generator_impl.cc
@@ -41,6 +41,8 @@
#include "src/sem/statement.h"
#include "src/sem/storage_texture_type.h"
#include "src/sem/struct.h"
+#include "src/sem/type_constructor.h"
+#include "src/sem/type_conversion.h"
#include "src/sem/variable.h"
#include "src/transform/calculate_array_length.h"
#include "src/transform/glsl.h"
@@ -358,85 +360,49 @@
bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) {
- const auto& args = expr->args;
- auto* ident = expr->func;
auto* call = builder_.Sem().Get(expr);
auto* target = call->Target();
if (auto* func = target->As<sem::Function>()) {
- if (ast::HasDecoration<
- transform::CalculateArrayLength::BufferSizeIntrinsic>(
- func->Declaration()->decorations)) {
- // Special function generated by the CalculateArrayLength transform for
- // calling X.GetDimensions(Y)
- if (!EmitExpression(out, args[0])) {
- return false;
- }
- out << ".GetDimensions(";
- if (!EmitExpression(out, args[1])) {
- return false;
- }
- out << ")";
- return true;
- }
+ return EmitFunctionCall(out, call, func);
}
-
- if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) {
- if (intrinsic->IsTexture()) {
- return EmitTextureCall(out, expr, intrinsic);
- } else if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
- return EmitSelectCall(out, expr);
- } else if (intrinsic->Type() == sem::IntrinsicType::kDot) {
- return EmitDotCall(out, expr, intrinsic);
- } else if (intrinsic->Type() == sem::IntrinsicType::kModf) {
- return EmitModfCall(out, expr, intrinsic);
- } else if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
- return EmitFrexpCall(out, expr, intrinsic);
- } else if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) {
- return EmitIsNormalCall(out, expr, intrinsic);
- } else if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
- return EmitExpression(out, expr->args[0]); // [DEPRECATED]
- } else if (intrinsic->IsDataPacking()) {
- return EmitDataPackingCall(out, expr, intrinsic);
- } else if (intrinsic->IsDataUnpacking()) {
- return EmitDataUnpackingCall(out, expr, intrinsic);
- } else if (intrinsic->IsBarrier()) {
- return EmitBarrierCall(out, intrinsic);
- } else if (intrinsic->IsAtomic()) {
- return EmitWorkgroupAtomicCall(out, expr, intrinsic);
- }
- auto name = generate_builtin_name(intrinsic);
- if (name.empty()) {
- return false;
- }
-
- out << name << "(";
-
- bool first = true;
- for (auto* arg : args) {
- if (!first) {
- out << ", ";
- }
- first = false;
-
- if (!EmitExpression(out, arg)) {
- return false;
- }
- }
-
- out << ")";
- return true;
+ if (auto* intrinsic = target->As<sem::Intrinsic>()) {
+ return EmitIntrinsicCall(out, call, intrinsic);
}
+ if (auto* cast = target->As<sem::TypeConversion>()) {
+ return EmitTypeConversion(out, call, cast);
+ }
+ if (auto* ctor = target->As<sem::TypeConstructor>()) {
+ return EmitTypeConstructor(out, call, ctor);
+ }
+ TINT_ICE(Writer, diagnostics_)
+ << "unhandled call target: " << target->TypeInfo().name;
+ return false;
+}
+
+bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
+ const sem::Call* call,
+ const sem::Function* func) {
+ const auto& args = call->Arguments();
+ auto* decl = call->Declaration();
+ auto* ident = decl->target.name;
auto name = builder_.Symbols().NameFor(ident->symbol);
auto caller_sym = ident->symbol;
- auto* func = builder_.AST().Functions().Find(ident->symbol);
- if (func == nullptr) {
- diagnostics_.add_error(diag::System::Writer,
- "Unable to find function: " +
- builder_.Symbols().NameFor(ident->symbol));
- return false;
+ if (ast::HasDecoration<transform::CalculateArrayLength::BufferSizeIntrinsic>(
+ func->Declaration()->decorations)) {
+ // Special function generated by the CalculateArrayLength transform for
+ // calling X.GetDimensions(Y)
+ if (!EmitExpression(out, args[0]->Declaration())) {
+ return false;
+ }
+ out << ".GetDimensions(";
+ if (!EmitExpression(out, args[1]->Declaration())) {
+ return false;
+ }
+ out << ")";
+ return true;
}
out << name << "(";
@@ -448,13 +414,141 @@
}
first = false;
- if (!EmitExpression(out, arg)) {
+ if (!EmitExpression(out, arg->Declaration())) {
return false;
}
}
out << ")";
+ return true;
+}
+bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
+ const sem::Call* call,
+ const sem::Intrinsic* intrinsic) {
+ auto* expr = call->Declaration();
+ if (intrinsic->IsTexture()) {
+ return EmitTextureCall(out, expr, intrinsic);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
+ return EmitSelectCall(out, expr);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kDot) {
+ return EmitDotCall(out, expr, intrinsic);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kModf) {
+ return EmitModfCall(out, expr, intrinsic);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
+ return EmitFrexpCall(out, expr, intrinsic);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) {
+ return EmitIsNormalCall(out, expr, intrinsic);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
+ return EmitExpression(out, expr->args[0]); // [DEPRECATED]
+ }
+ if (intrinsic->IsDataPacking()) {
+ return EmitDataPackingCall(out, expr, intrinsic);
+ }
+ if (intrinsic->IsDataUnpacking()) {
+ return EmitDataUnpackingCall(out, expr, intrinsic);
+ }
+ if (intrinsic->IsBarrier()) {
+ return EmitBarrierCall(out, intrinsic);
+ }
+ if (intrinsic->IsAtomic()) {
+ return EmitWorkgroupAtomicCall(out, expr, intrinsic);
+ }
+ auto name = generate_builtin_name(intrinsic);
+ if (name.empty()) {
+ return false;
+ }
+
+ out << name << "(";
+
+ bool first = true;
+ for (auto* arg : call->Arguments()) {
+ if (!first) {
+ out << ", ";
+ }
+ first = false;
+
+ if (!EmitExpression(out, arg->Declaration())) {
+ return false;
+ }
+ }
+
+ out << ")";
+ return true;
+}
+
+bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConversion* conv) {
+ if (!EmitType(out, conv->Target(), ast::StorageClass::kNone,
+ ast::Access::kReadWrite, "")) {
+ return false;
+ }
+ out << "(";
+
+ if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
+ return false;
+ }
+
+ out << ")";
+ return true;
+}
+
+bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConstructor* ctor) {
+ auto* type = ctor->ReturnType();
+
+ // If the type constructor is empty then we need to construct with the zero
+ // value for all components.
+ if (call->Arguments().empty()) {
+ return EmitZeroValue(out, type);
+ }
+
+ // For single-value vector initializers, swizzle the scalar to the right
+ // vector dimension using .x
+ const bool is_single_value_vector_init =
+ type->is_scalar_vector() && call->Arguments().size() == 1 &&
+ call->Arguments()[0]->Type()->UnwrapRef()->is_scalar();
+
+ auto it = structure_builders_.find(As<sem::Struct>(type));
+ if (it != structure_builders_.end()) {
+ out << it->second << "(";
+ } else {
+ if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
+ "")) {
+ return false;
+ }
+ out << "(";
+ }
+
+ if (is_single_value_vector_init) {
+ out << "(";
+ }
+
+ bool first = true;
+ for (auto* arg : call->Arguments()) {
+ if (!first) {
+ out << ", ";
+ }
+ first = false;
+
+ if (!EmitExpression(out, arg->Declaration())) {
+ return false;
+ }
+ }
+
+ if (is_single_value_vector_init) {
+ out << ")." << std::string(type->As<sem::Vector>()->Width(), 'x');
+ }
+
+ out << ")";
return true;
}
@@ -1148,13 +1242,13 @@
builder_.Sem().Add(zero, builder_.create<sem::Expression>(zero, i32, stmt,
sem::Constant{}));
auto* packed = AppendVector(&builder_, vector, zero);
- return EmitExpression(out, packed);
+ return EmitExpression(out, packed->Declaration());
};
auto emit_vector_appended_with_level = [&](const ast::Expression* vector) {
if (auto* level = arg(Usage::kLevel)) {
auto* packed = AppendVector(&builder_, vector, level);
- return EmitExpression(out, packed);
+ return EmitExpression(out, packed->Declaration());
}
return emit_vector_appended_with_i32_zero(vector);
};
@@ -1164,11 +1258,11 @@
auto* packed = AppendVector(&builder_, param_coords, array_index);
if (pack_level_in_coords) {
// Then mip level needs to be appended to the coordinates.
- if (!emit_vector_appended_with_level(packed)) {
+ if (!emit_vector_appended_with_level(packed->Declaration())) {
return false;
}
} else {
- if (!EmitExpression(out, packed)) {
+ if (!EmitExpression(out, packed->Declaration())) {
return false;
}
}
@@ -1347,58 +1441,6 @@
return true;
}
-bool GeneratorImpl::EmitTypeConstructor(
- std::ostream& out,
- const ast::TypeConstructorExpression* expr) {
- auto* type = TypeOf(expr)->UnwrapRef();
-
- // If the type constructor is empty then we need to construct with the zero
- // value for all components.
- if (expr->values.empty()) {
- return EmitZeroValue(out, type);
- }
-
- // For single-value vector initializers, swizzle the scalar to the right
- // vector dimension using .x
- const bool is_single_value_vector_init =
- type->is_scalar_vector() && expr->values.size() == 1 &&
- TypeOf(expr->values[0])->UnwrapRef()->is_scalar();
-
- auto it = structure_builders_.find(As<sem::Struct>(type));
- if (it != structure_builders_.end()) {
- out << it->second << "(";
- } else {
- if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
- "")) {
- return false;
- }
- out << "(";
- }
-
- if (is_single_value_vector_init) {
- out << "(";
- }
-
- bool first = true;
- for (auto* e : expr->values) {
- if (!first) {
- out << ", ";
- }
- first = false;
-
- if (!EmitExpression(out, e)) {
- return false;
- }
- }
-
- if (is_single_value_vector_init) {
- out << ")." << std::string(type->As<sem::Vector>()->Width(), 'x');
- }
-
- out << ")";
- return true;
-}
-
bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
if (!emit_continuing_()) {
return false;
@@ -1428,9 +1470,6 @@
if (auto* c = expr->As<ast::CallExpression>()) {
return EmitCall(out, c);
}
- if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
- return EmitTypeConstructor(out, c);
- }
if (auto* i = expr->As<ast::IdentifierExpression>()) {
return EmitIdentifier(out, i);
}
diff --git a/src/writer/glsl/generator_impl.h b/src/writer/glsl/generator_impl.h
index 85522b8..c28154c 100644
--- a/src/writer/glsl/generator_impl.h
+++ b/src/writer/glsl/generator_impl.h
@@ -43,6 +43,8 @@
namespace sem {
class Call;
class Intrinsic;
+class TypeConstructor;
+class TypeConversion;
} // namespace sem
namespace writer {
@@ -100,6 +102,38 @@
/// @param expr the call expression
/// @returns true if the call expression is emitted
bool EmitCall(std::ostream& out, const ast::CallExpression* expr);
+ /// Handles generating a function call expression
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param function the function being called
+ /// @returns true if the expression is emitted
+ bool EmitFunctionCall(std::ostream& out,
+ const sem::Call* call,
+ const sem::Function* function);
+ /// Handles generating an intrinsic call expression
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param intrinsic the intrinsic being called
+ /// @returns true if the expression is emitted
+ bool EmitIntrinsicCall(std::ostream& out,
+ const sem::Call* call,
+ const sem::Intrinsic* intrinsic);
+ /// Handles generating a type conversion expression
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param conv the type conversion
+ /// @returns true if the expression is emitted
+ bool EmitTypeConversion(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConversion* conv);
+ /// Handles generating a type constructor expression
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param ctor the type constructor
+ /// @returns true if the expression is emitted
+ bool EmitTypeConstructor(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConstructor* ctor);
/// Handles generating a barrier intrinsic call
/// @param out the output of the expression stream
/// @param intrinsic the semantic information for the barrier intrinsic
@@ -192,12 +226,6 @@
/// @param stmt the discard statement
/// @returns true if the statement was successfully emitted
bool EmitDiscard(const ast::DiscardStatement* stmt);
- /// Handles emitting a type constructor
- /// @param out the output of the expression stream
- /// @param expr the type constructor expression
- /// @returns true if the constructor is emitted
- bool EmitTypeConstructor(std::ostream& out,
- const ast::TypeConstructorExpression* expr);
/// Handles a continue statement
/// @param stmt the statement to emit
/// @returns true if the statement was emitted successfully
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 0751248..b8352c7 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -41,6 +41,8 @@
#include "src/sem/statement.h"
#include "src/sem/storage_texture_type.h"
#include "src/sem/struct.h"
+#include "src/sem/type_constructor.h"
+#include "src/sem/type_conversion.h"
#include "src/sem/variable.h"
#include "src/transform/add_empty_entry_point.h"
#include "src/transform/calculate_array_length.h"
@@ -499,7 +501,7 @@
case ast::BinaryOp::kDivide:
out << "/";
- if (auto val = program_->Sem().Get(expr->rhs)->ConstantValue()) {
+ if (auto val = builder_.Sem().Get(expr->rhs)->ConstantValue()) {
// Integer divide by zero is a DXC compile error, and undefined behavior
// in WGSL. Replace the 0 with 1.
if (val.Type()->Is<sem::I32>() && val.Elements()[0].i32 == 0) {
@@ -559,117 +561,209 @@
bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) {
- const auto& args = expr->args;
- auto* ident = expr->func;
auto* call = builder_.Sem().Get(expr);
auto* target = call->Target();
if (auto* func = target->As<sem::Function>()) {
- if (ast::HasDecoration<
- transform::CalculateArrayLength::BufferSizeIntrinsic>(
- func->Declaration()->decorations)) {
- // Special function generated by the CalculateArrayLength transform for
- // calling X.GetDimensions(Y)
- if (!EmitExpression(out, args[0])) {
- return false;
- }
- out << ".GetDimensions(";
- if (!EmitExpression(out, args[1])) {
- return false;
- }
- out << ")";
- return true;
- }
-
- if (auto* intrinsic =
- ast::GetDecoration<transform::DecomposeMemoryAccess::Intrinsic>(
- func->Declaration()->decorations)) {
- switch (intrinsic->storage_class) {
- case ast::StorageClass::kUniform:
- return EmitUniformBufferAccess(out, expr, intrinsic);
- case ast::StorageClass::kStorage:
- return EmitStorageBufferAccess(out, expr, intrinsic);
- default:
- TINT_UNREACHABLE(Writer, diagnostics_)
- << "unsupported DecomposeMemoryAccess::Intrinsic storage class:"
- << intrinsic->storage_class;
- return false;
- }
- }
+ return EmitFunctionCall(out, call, func);
}
+ if (auto* intrinsic = target->As<sem::Intrinsic>()) {
+ return EmitIntrinsicCall(out, call, intrinsic);
+ }
+ if (auto* conv = target->As<sem::TypeConversion>()) {
+ return EmitTypeConversion(out, call, conv);
+ }
+ if (auto* ctor = target->As<sem::TypeConstructor>()) {
+ return EmitTypeConstructor(out, call, ctor);
+ }
+ TINT_ICE(Writer, diagnostics_)
+ << "unhandled call target: " << target->TypeInfo().name;
+ return false;
+}
- if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) {
- if (intrinsic->IsTexture()) {
- return EmitTextureCall(out, expr, intrinsic);
- } else if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
- return EmitSelectCall(out, expr);
- } else if (intrinsic->Type() == sem::IntrinsicType::kModf) {
- return EmitModfCall(out, expr, intrinsic);
- } else if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
- return EmitFrexpCall(out, expr, intrinsic);
- } else if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) {
- return EmitIsNormalCall(out, expr, intrinsic);
- } else if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
- return EmitExpression(out, expr->args[0]); // [DEPRECATED]
- } else if (intrinsic->IsDataPacking()) {
- return EmitDataPackingCall(out, expr, intrinsic);
- } else if (intrinsic->IsDataUnpacking()) {
- return EmitDataUnpackingCall(out, expr, intrinsic);
- } else if (intrinsic->IsBarrier()) {
- return EmitBarrierCall(out, intrinsic);
- } else if (intrinsic->IsAtomic()) {
- return EmitWorkgroupAtomicCall(out, expr, intrinsic);
- }
- auto name = generate_builtin_name(intrinsic);
- if (name.empty()) {
+bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
+ const sem::Call* call,
+ const sem::Function* func) {
+ auto* expr = call->Declaration();
+
+ if (ast::HasDecoration<transform::CalculateArrayLength::BufferSizeIntrinsic>(
+ func->Declaration()->decorations)) {
+ // Special function generated by the CalculateArrayLength transform for
+ // calling X.GetDimensions(Y)
+ if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
return false;
}
-
- out << name << "(";
-
- bool first = true;
- for (auto* arg : args) {
- if (!first) {
- out << ", ";
- }
- first = false;
-
- if (!EmitExpression(out, arg)) {
- return false;
- }
+ out << ".GetDimensions(";
+ if (!EmitExpression(out, call->Arguments()[1]->Declaration())) {
+ return false;
}
-
out << ")";
return true;
}
- auto name = builder_.Symbols().NameFor(ident->symbol);
- auto caller_sym = ident->symbol;
+ if (auto* intrinsic =
+ ast::GetDecoration<transform::DecomposeMemoryAccess::Intrinsic>(
+ func->Declaration()->decorations)) {
+ switch (intrinsic->storage_class) {
+ case ast::StorageClass::kUniform:
+ return EmitUniformBufferAccess(out, expr, intrinsic);
+ case ast::StorageClass::kStorage:
+ return EmitStorageBufferAccess(out, expr, intrinsic);
+ default:
+ TINT_UNREACHABLE(Writer, diagnostics_)
+ << "unsupported DecomposeMemoryAccess::Intrinsic storage class:"
+ << intrinsic->storage_class;
+ return false;
+ }
+ }
- auto* func = builder_.AST().Functions().Find(ident->symbol);
- if (func == nullptr) {
- diagnostics_.add_error(diag::System::Writer,
- "Unable to find function: " +
- builder_.Symbols().NameFor(ident->symbol));
+ out << builder_.Symbols().NameFor(func->Declaration()->symbol) << "(";
+
+ bool first = true;
+ for (auto* arg : call->Arguments()) {
+ if (!first) {
+ out << ", ";
+ }
+ first = false;
+
+ if (!EmitExpression(out, arg->Declaration())) {
+ return false;
+ }
+ }
+
+ out << ")";
+ return true;
+}
+
+bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
+ const sem::Call* call,
+ const sem::Intrinsic* intrinsic) {
+ auto* expr = call->Declaration();
+ if (intrinsic->IsTexture()) {
+ return EmitTextureCall(out, expr, intrinsic);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
+ return EmitSelectCall(out, expr);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kModf) {
+ return EmitModfCall(out, expr, intrinsic);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
+ return EmitFrexpCall(out, expr, intrinsic);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kIsNormal) {
+ return EmitIsNormalCall(out, expr, intrinsic);
+ }
+ if (intrinsic->Type() == sem::IntrinsicType::kIgnore) {
+ return EmitExpression(out, expr->args[0]); // [DEPRECATED]
+ }
+ if (intrinsic->IsDataPacking()) {
+ return EmitDataPackingCall(out, expr, intrinsic);
+ }
+ if (intrinsic->IsDataUnpacking()) {
+ return EmitDataUnpackingCall(out, expr, intrinsic);
+ }
+ if (intrinsic->IsBarrier()) {
+ return EmitBarrierCall(out, intrinsic);
+ }
+ if (intrinsic->IsAtomic()) {
+ return EmitWorkgroupAtomicCall(out, expr, intrinsic);
+ }
+
+ auto name = generate_builtin_name(intrinsic);
+ if (name.empty()) {
return false;
}
out << name << "(";
bool first = true;
- for (auto* arg : args) {
+ for (auto* arg : call->Arguments()) {
if (!first) {
out << ", ";
}
first = false;
- if (!EmitExpression(out, arg)) {
+ if (!EmitExpression(out, arg->Declaration())) {
return false;
}
}
out << ")";
+ return true;
+}
+bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConversion* conv) {
+ if (!EmitType(out, conv->Target(), ast::StorageClass::kNone,
+ ast::Access::kReadWrite, "")) {
+ return false;
+ }
+ out << "(";
+
+ if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
+ return false;
+ }
+
+ out << ")";
+ return true;
+}
+
+bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConstructor* ctor) {
+ auto* type = call->Type();
+
+ // If the type constructor is empty then we need to construct with the zero
+ // value for all components.
+ if (call->Arguments().empty()) {
+ return EmitZeroValue(out, type);
+ }
+
+ bool brackets = type->IsAnyOf<sem::Array, sem::Struct>();
+
+ // For single-value vector initializers, swizzle the scalar to the right
+ // vector dimension using .x
+ const bool is_single_value_vector_init =
+ type->is_scalar_vector() && call->Arguments().size() == 1 &&
+ ctor->Parameters()[0]->Type()->is_scalar();
+
+ auto it = structure_builders_.find(As<sem::Struct>(type));
+ if (it != structure_builders_.end()) {
+ out << it->second << "(";
+ brackets = false;
+ } else if (brackets) {
+ out << "{";
+ } else {
+ if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
+ "")) {
+ return false;
+ }
+ out << "(";
+ }
+
+ if (is_single_value_vector_init) {
+ out << "(";
+ }
+
+ bool first = true;
+ for (auto* e : call->Arguments()) {
+ if (!first) {
+ out << ", ";
+ }
+ first = false;
+
+ if (!EmitExpression(out, e->Declaration())) {
+ return false;
+ }
+ }
+
+ if (is_single_value_vector_init) {
+ out << ")." << std::string(type->As<sem::Vector>()->Width(), 'x');
+ }
+
+ out << (brackets ? "}" : ")");
return true;
}
@@ -1892,13 +1986,13 @@
builder_.Sem().Add(zero, builder_.create<sem::Expression>(zero, i32, stmt,
sem::Constant{}));
auto* packed = AppendVector(&builder_, vector, zero);
- return EmitExpression(out, packed);
+ return EmitExpression(out, packed->Declaration());
};
auto emit_vector_appended_with_level = [&](const ast::Expression* vector) {
if (auto* level = arg(Usage::kLevel)) {
auto* packed = AppendVector(&builder_, vector, level);
- return EmitExpression(out, packed);
+ return EmitExpression(out, packed->Declaration());
}
return emit_vector_appended_with_i32_zero(vector);
};
@@ -1908,11 +2002,11 @@
auto* packed = AppendVector(&builder_, param_coords, array_index);
if (pack_level_in_coords) {
// Then mip level needs to be appended to the coordinates.
- if (!emit_vector_appended_with_level(packed)) {
+ if (!emit_vector_appended_with_level(packed->Declaration())) {
return false;
}
} else {
- if (!EmitExpression(out, packed)) {
+ if (!EmitExpression(out, packed->Declaration())) {
return false;
}
}
@@ -2112,63 +2206,6 @@
return true;
}
-bool GeneratorImpl::EmitTypeConstructor(
- std::ostream& out,
- const ast::TypeConstructorExpression* expr) {
- auto* type = TypeOf(expr)->UnwrapRef();
-
- // If the type constructor is empty then we need to construct with the zero
- // value for all components.
- if (expr->values.empty()) {
- return EmitZeroValue(out, type);
- }
-
- bool brackets = type->IsAnyOf<sem::Array, sem::Struct>();
-
- // For single-value vector initializers, swizzle the scalar to the right
- // vector dimension using .x
- const bool is_single_value_vector_init =
- type->is_scalar_vector() && expr->values.size() == 1 &&
- TypeOf(expr->values[0])->UnwrapRef()->is_scalar();
-
- auto it = structure_builders_.find(As<sem::Struct>(type));
- if (it != structure_builders_.end()) {
- out << it->second << "(";
- brackets = false;
- } else if (brackets) {
- out << "{";
- } else {
- if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kReadWrite,
- "")) {
- return false;
- }
- out << "(";
- }
-
- if (is_single_value_vector_init) {
- out << "(";
- }
-
- bool first = true;
- for (auto* e : expr->values) {
- if (!first) {
- out << ", ";
- }
- first = false;
-
- if (!EmitExpression(out, e)) {
- return false;
- }
- }
-
- if (is_single_value_vector_init) {
- out << ")." << std::string(type->As<sem::Vector>()->Width(), 'x');
- }
-
- out << (brackets ? "}" : ")");
- return true;
-}
-
bool GeneratorImpl::EmitContinue(const ast::ContinueStatement*) {
if (!emit_continuing_()) {
return false;
@@ -2198,9 +2235,6 @@
if (auto* c = expr->As<ast::CallExpression>()) {
return EmitCall(out, c);
}
- if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
- return EmitTypeConstructor(out, c);
- }
if (auto* i = expr->As<ast::IdentifierExpression>()) {
return EmitIdentifier(out, i);
}
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index e63524f..7c56b10 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -44,6 +44,8 @@
namespace sem {
class Call;
class Intrinsic;
+class TypeConstructor;
+class TypeConversion;
} // namespace sem
namespace writer {
@@ -116,6 +118,38 @@
/// @param expr the call expression
/// @returns true if the call expression is emitted
bool EmitCall(std::ostream& out, const ast::CallExpression* expr);
+ /// Handles generating a function call expression
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param function the function being called
+ /// @returns true if the expression is emitted
+ bool EmitFunctionCall(std::ostream& out,
+ const sem::Call* call,
+ const sem::Function* function);
+ /// Handles generating an intrinsic call expression
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param intrinsic the intrinsic being called
+ /// @returns true if the expression is emitted
+ bool EmitIntrinsicCall(std::ostream& out,
+ const sem::Call* call,
+ const sem::Intrinsic* intrinsic);
+ /// Handles generating a type conversion expression
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param conv the type conversion
+ /// @returns true if the expression is emitted
+ bool EmitTypeConversion(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConversion* conv);
+ /// Handles generating a type constructor expression
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param ctor the type constructor
+ /// @returns true if the expression is emitted
+ bool EmitTypeConstructor(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConstructor* ctor);
/// Handles generating a call expression to a
/// transform::DecomposeMemoryAccess::Intrinsic for a uniform buffer
/// @param out the output of the expression stream
@@ -221,12 +255,6 @@
/// @param stmt the discard statement
/// @returns true if the statement was successfully emitted
bool EmitDiscard(const ast::DiscardStatement* stmt);
- /// Handles emitting a type constructor
- /// @param out the output of the expression stream
- /// @param expr the type constructor expression
- /// @returns true if the constructor is emitted
- bool EmitTypeConstructor(std::ostream& out,
- const ast::TypeConstructorExpression* expr);
/// Handles a continue statement
/// @param stmt the statement to emit
/// @returns true if the statement was emitted successfully
diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc
index 94e011d..f7ec04f 100644
--- a/src/writer/msl/generator_impl.cc
+++ b/src/writer/msl/generator_impl.cc
@@ -51,6 +51,8 @@
#include "src/sem/sampled_texture_type.h"
#include "src/sem/storage_texture_type.h"
#include "src/sem/struct.h"
+#include "src/sem/type_constructor.h"
+#include "src/sem/type_conversion.h"
#include "src/sem/u32_type.h"
#include "src/sem/variable.h"
#include "src/sem/vector_type.h"
@@ -242,10 +244,9 @@
std::ostream& out,
const ast::IndexAccessorExpression* expr) {
bool paren_lhs =
- !expr->object
- ->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
- ast::IdentifierExpression, ast::MemberAccessorExpression,
- ast::TypeConstructorExpression>();
+ !expr->object->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
+ ast::IdentifierExpression,
+ ast::MemberAccessorExpression>();
if (paren_lhs) {
out << "(";
@@ -496,43 +497,53 @@
bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) {
- auto* ident = expr->func;
auto* call = program_->Sem().Get(expr);
- if (auto* intrinsic = call->Target()->As<sem::Intrinsic>()) {
- return EmitIntrinsicCall(out, expr, intrinsic);
+ auto* target = call->Target();
+
+ if (auto* func = target->As<sem::Function>()) {
+ return EmitFunctionCall(out, call, func);
+ }
+ if (auto* intrinsic = target->As<sem::Intrinsic>()) {
+ return EmitIntrinsicCall(out, call, intrinsic);
+ }
+ if (auto* conv = target->As<sem::TypeConversion>()) {
+ return EmitTypeConversion(out, call, conv);
+ }
+ if (auto* ctor = target->As<sem::TypeConstructor>()) {
+ return EmitTypeConstructor(out, call, ctor);
}
- auto* func = program_->AST().Functions().Find(ident->symbol);
- if (func == nullptr) {
- diagnostics_.add_error(diag::System::Writer,
- "Unable to find function: " +
- program_->Symbols().NameFor(ident->symbol));
- return false;
- }
+ TINT_ICE(Writer, diagnostics_)
+ << "unhandled call target: " << target->TypeInfo().name;
+ return false;
+}
+bool GeneratorImpl::EmitFunctionCall(std::ostream& out,
+ const sem::Call* call,
+ const sem::Function*) {
+ auto* ident = call->Declaration()->target.name;
out << program_->Symbols().NameFor(ident->symbol) << "(";
bool first = true;
- const auto& args = expr->args;
- for (auto* arg : args) {
+ for (auto* arg : call->Arguments()) {
if (!first) {
out << ", ";
}
first = false;
- if (!EmitExpression(out, arg)) {
+ if (!EmitExpression(out, arg->Declaration())) {
return false;
}
}
out << ")";
-
return true;
}
bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
- const ast::CallExpression* expr,
+ const sem::Call* call,
const sem::Intrinsic* intrinsic) {
+ auto* expr = call->Declaration();
if (intrinsic->IsAtomic()) {
return EmitAtomicCall(out, expr, intrinsic);
}
@@ -634,6 +645,64 @@
return true;
}
+bool GeneratorImpl::EmitTypeConversion(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConversion* conv) {
+ if (!EmitType(out, conv->Target(), "")) {
+ return false;
+ }
+ out << "(";
+
+ if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
+ return false;
+ }
+
+ out << ")";
+ return true;
+}
+
+bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConstructor* ctor) {
+ auto* type = ctor->ReturnType();
+
+ if (type->IsAnyOf<sem::Array, sem::Struct>()) {
+ out << "{";
+ } else {
+ if (!EmitType(out, type, "")) {
+ return false;
+ }
+ out << "(";
+ }
+
+ int i = 0;
+ for (auto* arg : call->Arguments()) {
+ if (i > 0) {
+ out << ", ";
+ }
+
+ if (auto* struct_ty = type->As<sem::Struct>()) {
+ // Emit field designators for structures to account for padding members.
+ auto* member = struct_ty->Members()[i]->Declaration();
+ auto name = program_->Symbols().NameFor(member->symbol);
+ out << "." << name << "=";
+ }
+
+ if (!EmitExpression(out, arg->Declaration())) {
+ return false;
+ }
+
+ i++;
+ }
+
+ if (type->IsAnyOf<sem::Array, sem::Struct>()) {
+ out << "}";
+ } else {
+ out << ")";
+ }
+ return true;
+}
+
bool GeneratorImpl::EmitAtomicCall(std::ostream& out,
const ast::CallExpression* expr,
const sem::Intrinsic* intrinsic) {
@@ -762,10 +831,9 @@
// accessor used for the function calls.
auto texture_expr = [&]() {
bool paren_lhs =
- !texture
- ->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
- ast::IdentifierExpression, ast::MemberAccessorExpression,
- ast::TypeConstructorExpression>();
+ !texture->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
+ ast::IdentifierExpression,
+ ast::MemberAccessorExpression>();
if (paren_lhs) {
out << "(";
}
@@ -1300,48 +1368,6 @@
return true;
}
-bool GeneratorImpl::EmitTypeConstructor(
- std::ostream& out,
- const ast::TypeConstructorExpression* expr) {
- auto* type = TypeOf(expr)->UnwrapRef();
-
- if (type->IsAnyOf<sem::Array, sem::Struct>()) {
- out << "{";
- } else {
- if (!EmitType(out, type, "")) {
- return false;
- }
- out << "(";
- }
-
- int i = 0;
- for (auto* e : expr->values) {
- if (i > 0) {
- out << ", ";
- }
-
- if (auto* struct_ty = type->As<sem::Struct>()) {
- // Emit field designators for structures to account for padding members.
- auto* member = struct_ty->Members()[i]->Declaration();
- auto name = program_->Symbols().NameFor(member->symbol);
- out << "." << name << "=";
- }
-
- if (!EmitExpression(out, e)) {
- return false;
- }
-
- i++;
- }
-
- if (type->IsAnyOf<sem::Array, sem::Struct>()) {
- out << "}";
- } else {
- out << ")";
- }
- return true;
-}
-
bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
if (type->Is<sem::Bool>()) {
out << "false";
@@ -1426,9 +1452,6 @@
if (auto* c = expr->As<ast::CallExpression>()) {
return EmitCall(out, c);
}
- if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
- return EmitTypeConstructor(out, c);
- }
if (auto* i = expr->As<ast::IdentifierExpression>()) {
return EmitIdentifier(out, i);
}
@@ -1899,11 +1922,9 @@
std::ostream& out,
const ast::MemberAccessorExpression* expr) {
auto write_lhs = [&] {
- bool paren_lhs =
- !expr->structure
- ->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
- ast::IdentifierExpression, ast::MemberAccessorExpression,
- ast::TypeConstructorExpression>();
+ bool paren_lhs = !expr->structure->IsAnyOf<
+ ast::IndexAccessorExpression, ast::CallExpression,
+ ast::IdentifierExpression, ast::MemberAccessorExpression>();
if (paren_lhs) {
out << "(";
}
diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h
index 40b16cc..8b73a31 100644
--- a/src/writer/msl/generator_impl.h
+++ b/src/writer/msl/generator_impl.h
@@ -33,7 +33,6 @@
#include "src/ast/member_accessor_expression.h"
#include "src/ast/return_statement.h"
#include "src/ast/switch_statement.h"
-#include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_op_expression.h"
#include "src/program.h"
#include "src/scope_stack.h"
@@ -46,6 +45,8 @@
namespace sem {
class Call;
class Intrinsic;
+class TypeConstructor;
+class TypeConversion;
} // namespace sem
namespace writer {
@@ -130,12 +131,36 @@
bool EmitCall(std::ostream& out, const ast::CallExpression* expr);
/// Handles generating an intrinsic call expression
/// @param out the output of the expression stream
- /// @param expr the call expression
+ /// @param call the call expression
/// @param intrinsic the intrinsic being called
/// @returns true if the call expression is emitted
bool EmitIntrinsicCall(std::ostream& out,
- const ast::CallExpression* expr,
+ const sem::Call* call,
const sem::Intrinsic* intrinsic);
+ /// Handles generating a type conversion expression
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param conv the type conversion
+ /// @returns true if the expression is emitted
+ bool EmitTypeConversion(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConversion* conv);
+ /// Handles generating a type constructor
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param ctor the type constructor
+ /// @returns true if the constructor is emitted
+ bool EmitTypeConstructor(std::ostream& out,
+ const sem::Call* call,
+ const sem::TypeConstructor* ctor);
+ /// Handles generating a function call
+ /// @param out the output of the expression stream
+ /// @param call the call expression
+ /// @param func the target function
+ /// @returns true if the call is emitted
+ bool EmitFunctionCall(std::ostream& out,
+ const sem::Call* call,
+ const sem::Function* func);
/// Handles generating a call to an atomic function (`atomicAdd`,
/// `atomicMax`, etc)
/// @param out the output of the expression stream
@@ -293,12 +318,6 @@
/// @param str the struct to generate
/// @returns true if the struct is emitted
bool EmitStructType(TextBuffer* buffer, const sem::Struct* str);
- /// Handles emitting a type constructor
- /// @param out the output of the expression stream
- /// @param expr the type constructor expression
- /// @returns true if the constructor is emitted
- bool EmitTypeConstructor(std::ostream& out,
- const ast::TypeConstructorExpression* expr);
/// Handles a unary op expression
/// @param out the output of the expression stream
/// @param expr the expression to emit
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 697509d..5fbdfb7 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -22,6 +22,7 @@
#include "src/ast/fallthrough_statement.h"
#include "src/ast/internal_decoration.h"
#include "src/ast/override_decoration.h"
+#include "src/ast/traverse_expressions.h"
#include "src/sem/array.h"
#include "src/sem/atomic_type.h"
#include "src/sem/call.h"
@@ -33,7 +34,10 @@
#include "src/sem/multisampled_texture_type.h"
#include "src/sem/reference_type.h"
#include "src/sem/sampled_texture_type.h"
+#include "src/sem/statement.h"
#include "src/sem/struct.h"
+#include "src/sem/type_constructor.h"
+#include "src/sem/type_conversion.h"
#include "src/sem/variable.h"
#include "src/sem/vector_type.h"
#include "src/transform/add_empty_entry_point.h"
@@ -577,9 +581,6 @@
if (auto* c = expr->As<ast::CallExpression>()) {
return GenerateCallExpression(c);
}
- if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
- return GenerateConstructorExpression(nullptr, c);
- }
if (auto* i = expr->As<ast::IdentifierExpression>()) {
return GenerateIdentifierExpression(i);
}
@@ -1259,80 +1260,44 @@
if (auto* literal = expr->As<ast::LiteralExpression>()) {
return GenerateLiteralIfNeeded(var, literal);
}
- if (auto* type = expr->As<ast::TypeConstructorExpression>()) {
- return GenerateTypeConstructorExpression(var, type);
+ if (auto* call = builder_.Sem().Get<sem::Call>(expr)) {
+ if (call->Target()->IsAnyOf<sem::TypeConstructor, sem::TypeConversion>()) {
+ return GenerateTypeConstructorOrConversion(call, var);
+ }
}
-
error_ = "unknown constructor expression";
return 0;
}
-bool Builder::is_constructor_const(const ast::Expression* expr,
- bool is_global_init) {
- if (expr->Is<ast::LiteralExpression>()) {
- return true;
- }
+bool Builder::IsConstructorConst(const ast::Expression* expr) {
+ bool is_const = true;
+ ast::TraverseExpressions(expr, builder_.Diagnostics(),
+ [&](const ast::Expression* e) {
+ if (e->Is<ast::LiteralExpression>()) {
+ return ast::TraverseAction::Descend;
+ }
+ if (auto* ce = e->As<ast::CallExpression>()) {
+ auto* call = builder_.Sem().Get(ce);
+ if (call->Target()->Is<sem::TypeConstructor>()) {
+ return ast::TraverseAction::Descend;
+ }
+ }
- auto* tc = expr->As<ast::TypeConstructorExpression>();
- if (!tc) {
- return false;
- }
- auto* result_type = TypeOf(tc)->UnwrapRef();
- for (size_t i = 0; i < tc->values.size(); ++i) {
- auto* e = tc->values[i];
-
- if (!e->IsAnyOf<ast::TypeConstructorExpression, ast::LiteralExpression>()) {
- if (is_global_init) {
- error_ = "constructor must be a constant expression";
- return false;
- }
- return false;
- }
- if (!is_constructor_const(e, is_global_init)) {
- return false;
- }
- if (has_error()) {
- return false;
- }
-
- auto* lit = e->As<ast::LiteralExpression>();
- if (result_type->Is<sem::Vector>() && lit == nullptr) {
- return false;
- }
-
- // This should all be handled by |is_constructor_const| call above
- if (lit == nullptr) {
- continue;
- }
-
- const sem::Type* subtype = result_type->UnwrapRef();
- if (auto* vec = subtype->As<sem::Vector>()) {
- subtype = vec->type();
- } else if (auto* mat = subtype->As<sem::Matrix>()) {
- subtype = mat->type();
- } else if (auto* arr = subtype->As<sem::Array>()) {
- subtype = arr->ElemType();
- } else if (auto* str = subtype->As<sem::Struct>()) {
- subtype = str->Members()[i]->Type();
- }
- if (subtype != TypeOf(lit)->UnwrapRef()) {
- return false;
- }
- }
- return true;
+ is_const = false;
+ return ast::TraverseAction::Stop;
+ });
+ return is_const;
}
-uint32_t Builder::GenerateTypeConstructorExpression(
- const ast::Variable* var,
- const ast::TypeConstructorExpression* init) {
+uint32_t Builder::GenerateTypeConstructorOrConversion(
+ const sem::Call* call,
+ const ast::Variable* var) {
+ auto& args = call->Arguments();
auto* global_var = builder_.Sem().Get<sem::GlobalVariable>(var);
-
- auto& values = init->values;
-
- auto* result_type = TypeOf(init);
+ auto* result_type = call->Type();
// Generate the zero initializer if there are no values provided.
- if (values.empty()) {
+ if (args.empty()) {
if (global_var && global_var->IsOverridable()) {
auto constant_id = global_var->ConstantId();
if (result_type->Is<sem::I32>()) {
@@ -1356,10 +1321,10 @@
}
std::ostringstream out;
- out << "__const_" << init->type->FriendlyName(builder_.Symbols()) << "_";
+ out << "__const_" << result_type->FriendlyName(builder_.Symbols()) << "_";
result_type = result_type->UnwrapRef();
- bool constructor_is_const = is_constructor_const(init, global_var);
+ bool constructor_is_const = IsConstructorConst(call->Declaration());
if (has_error()) {
return 0;
}
@@ -1368,7 +1333,7 @@
if (auto* res_vec = result_type->As<sem::Vector>()) {
if (res_vec->type()->is_scalar()) {
- auto* value_type = TypeOf(values[0])->UnwrapRef();
+ auto* value_type = args[0]->Type()->UnwrapRef();
if (auto* val_vec = value_type->As<sem::Vector>()) {
if (val_vec->type()->is_scalar()) {
can_cast_or_copy = res_vec->Width() == val_vec->Width();
@@ -1378,7 +1343,8 @@
}
if (can_cast_or_copy) {
- return GenerateCastOrCopyOrPassthrough(result_type, values[0], global_var);
+ return GenerateCastOrCopyOrPassthrough(result_type, args[0]->Declaration(),
+ global_var);
}
auto type_id = GenerateTypeIfNeeded(result_type);
@@ -1394,19 +1360,18 @@
}
OperandList ops;
- for (auto* e : values) {
+ for (auto* e : args) {
uint32_t id = 0;
- if (constructor_is_const) {
- id = GenerateConstructorExpression(nullptr, e);
- } else {
- id = GenerateExpression(e);
- id = GenerateLoadIfNeeded(TypeOf(e), id);
+ id = GenerateExpression(e->Declaration());
+ if (id == 0) {
+ return 0;
}
+ id = GenerateLoadIfNeeded(e->Type(), id);
if (id == 0) {
return 0;
}
- auto* value_type = TypeOf(e)->UnwrapRef();
+ auto* value_type = e->Type()->UnwrapRef();
// If the result and value types are the same we can just use the object.
// If the result is not a vector then we should have validated that the
// value type is a correctly sized vector so we can just use it directly.
@@ -1421,7 +1386,8 @@
// Both scalars, but not the same type so we need to generate a conversion
// of the value.
if (value_type->is_scalar() && result_type->is_scalar()) {
- id = GenerateCastOrCopyOrPassthrough(result_type, values[0], global_var);
+ id = GenerateCastOrCopyOrPassthrough(result_type, args[0]->Declaration(),
+ global_var);
out << "_" << id;
ops.push_back(Operand::Int(id));
continue;
@@ -1483,9 +1449,9 @@
}
// For a single-value vector initializer, splat the initializer value.
- auto* const init_result_type = TypeOf(init)->UnwrapRef();
- if (values.size() == 1 && init_result_type->is_scalar_vector() &&
- TypeOf(values[0])->UnwrapRef()->is_scalar()) {
+ auto* const init_result_type = call->Type()->UnwrapRef();
+ if (args.size() == 1 && init_result_type->is_scalar_vector() &&
+ args[0]->Type()->UnwrapRef()->is_scalar()) {
size_t vec_size = init_result_type->As<sem::Vector>()->Width();
for (size_t i = 0; i < (vec_size - 1); ++i) {
ops.push_back(ops[0]);
@@ -2232,14 +2198,29 @@
}
uint32_t Builder::GenerateCallExpression(const ast::CallExpression* expr) {
- auto* ident = expr->func;
auto* call = builder_.Sem().Get(expr);
auto* target = call->Target();
- if (auto* intrinsic = target->As<sem::Intrinsic>()) {
- return GenerateIntrinsic(expr, intrinsic);
- }
- auto type_id = GenerateTypeIfNeeded(target->ReturnType());
+ if (auto* func = target->As<sem::Function>()) {
+ return GenerateFunctionCall(call, func);
+ }
+ if (auto* intrinsic = target->As<sem::Intrinsic>()) {
+ return GenerateIntrinsicCall(call, intrinsic);
+ }
+ if (target->IsAnyOf<sem::TypeConversion, sem::TypeConstructor>()) {
+ return GenerateTypeConstructorOrConversion(call, nullptr);
+ }
+ TINT_ICE(Writer, builder_.Diagnostics())
+ << "unhandled call target: " << target->TypeInfo().name;
+ return false;
+}
+
+uint32_t Builder::GenerateFunctionCall(const sem::Call* call,
+ const sem::Function*) {
+ auto* expr = call->Declaration();
+ auto* ident = expr->target.name;
+
+ auto type_id = GenerateTypeIfNeeded(call->Type());
if (type_id == 0) {
return 0;
}
@@ -2278,8 +2259,8 @@
return result_id;
}
-uint32_t Builder::GenerateIntrinsic(const ast::CallExpression* call,
- const sem::Intrinsic* intrinsic) {
+uint32_t Builder::GenerateIntrinsicCall(const sem::Call* call,
+ const sem::Intrinsic* intrinsic) {
auto result = result_op();
auto result_id = result.to_i();
@@ -2323,15 +2304,15 @@
// and loads it if necessary. Returns 0 on error.
auto get_arg_as_value_id = [&](size_t i,
bool generate_load = true) -> uint32_t {
- auto* arg = call->args[i];
+ auto* arg = call->Arguments()[i];
auto* param = intrinsic->Parameters()[i];
- auto val_id = GenerateExpression(arg);
+ auto val_id = GenerateExpression(arg->Declaration());
if (val_id == 0) {
return 0;
}
if (generate_load && !param->Type()->Is<sem::Pointer>()) {
- val_id = GenerateLoadIfNeeded(TypeOf(arg), val_id);
+ val_id = GenerateLoadIfNeeded(arg->Type(), val_id);
}
return val_id;
};
@@ -2364,13 +2345,8 @@
op = spv::Op::OpAll;
break;
case IntrinsicType::kArrayLength: {
- if (call->args.empty()) {
- error_ = "missing param for runtime array length";
- return 0;
- }
- auto* arg = call->args[0];
-
- auto* address_of = arg->As<ast::UnaryOpExpression>();
+ auto* address_of =
+ call->Arguments()[0]->Declaration()->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
error_ = "arrayLength() expected pointer to member access, got " +
std::string(address_of->TypeInfo().name);
@@ -2695,7 +2671,7 @@
return 0;
}
- for (size_t i = 0; i < call->args.size(); i++) {
+ for (size_t i = 0; i < call->Arguments().size(); i++) {
if (auto val_id = get_arg_as_value_id(i)) {
params.emplace_back(Operand::Int(val_id));
} else {
@@ -2710,22 +2686,22 @@
return result_id;
}
-bool Builder::GenerateTextureIntrinsic(const ast::CallExpression* call,
+bool Builder::GenerateTextureIntrinsic(const sem::Call* call,
const sem::Intrinsic* intrinsic,
Operand result_type,
Operand result_id) {
using Usage = sem::ParameterUsage;
auto& signature = intrinsic->Signature();
- auto arguments = call->args;
+ auto& arguments = call->Arguments();
// Generates the given expression, returning the operand ID
- auto gen = [&](const ast::Expression* expr) {
- auto val_id = GenerateExpression(expr);
+ auto gen = [&](const sem::Expression* expr) {
+ auto val_id = GenerateExpression(expr->Declaration());
if (val_id == 0) {
return Operand::Int(0);
}
- val_id = GenerateLoadIfNeeded(TypeOf(expr), val_id);
+ val_id = GenerateLoadIfNeeded(expr->Type(), val_id);
return Operand::Int(val_id);
};
@@ -2751,7 +2727,7 @@
TINT_ICE(Writer, builder_.Diagnostics()) << "missing texture argument";
}
- auto* texture_type = TypeOf(texture)->UnwrapRef()->As<sem::Texture>();
+ auto* texture_type = texture->Type()->UnwrapRef()->As<sem::Texture>();
auto op = spv::Op::OpNop;
@@ -2819,7 +2795,7 @@
} else {
// Assign post_emission to swizzle the result of the call to
// OpImageQuerySize[Lod].
- auto* element_type = ElementTypeOf(TypeOf(call));
+ auto* element_type = ElementTypeOf(call->Type());
auto spirv_result = result_op();
auto* spirv_result_type =
builder_.create<sem::Vector>(element_type, spirv_result_width);
@@ -2856,8 +2832,9 @@
auto append_coords_to_spirv_params = [&]() -> bool {
if (auto* array_index = arg(Usage::kArrayIndex)) {
// Array index needs to be appended to the coordinates.
- auto* packed = AppendVector(&builder_, arg(Usage::kCoords), array_index);
- auto param = GenerateTypeConstructorExpression(nullptr, packed);
+ auto* packed = AppendVector(&builder_, arg(Usage::kCoords)->Declaration(),
+ array_index->Declaration());
+ auto param = GenerateExpression(packed->Declaration());
if (param == 0) {
return false;
}
@@ -3026,7 +3003,7 @@
return false;
}
auto level = Operand::Int(0);
- if (TypeOf(arg(Usage::kLevel))->Is<sem::I32>()) {
+ if (arg(Usage::kLevel)->Type()->UnwrapRef()->Is<sem::I32>()) {
// Depth textures have i32 parameters for the level, but SPIR-V expects
// F32. Cast.
auto f32_type_id = GenerateTypeIfNeeded(builder_.create<sem::F32>());
@@ -3156,7 +3133,7 @@
});
}
-bool Builder::GenerateAtomicIntrinsic(const ast::CallExpression* call,
+bool Builder::GenerateAtomicIntrinsic(const sem::Call* call,
const sem::Intrinsic* intrinsic,
Operand result_type,
Operand result_id) {
@@ -3193,18 +3170,18 @@
return false;
}
- uint32_t pointer_id = GenerateExpression(call->args[0]);
+ uint32_t pointer_id = GenerateExpression(call->Arguments()[0]->Declaration());
if (pointer_id == 0) {
return false;
}
uint32_t value_id = 0;
- if (call->args.size() > 1) {
- value_id = GenerateExpression(call->args.back());
+ if (call->Arguments().size() > 1) {
+ value_id = GenerateExpression(call->Arguments().back()->Declaration());
if (value_id == 0) {
return false;
}
- value_id = GenerateLoadIfNeeded(TypeOf(call->args.back()), value_id);
+ value_id = GenerateLoadIfNeeded(call->Arguments().back()->Type(), value_id);
if (value_id == 0) {
return false;
}
@@ -3308,12 +3285,12 @@
value,
});
case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
- auto comparator = GenerateExpression(call->args[1]);
+ auto comparator = GenerateExpression(call->Arguments()[1]->Declaration());
if (comparator == 0) {
return false;
}
- auto* value_sem_type = TypeOf(call->args[2]);
+ auto* value_sem_type = TypeOf(call->Arguments()[2]->Declaration());
auto value_type = GenerateTypeIfNeeded(value_sem_type);
if (value_type == 0) {
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 66e4f4a..d85e988 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -46,6 +46,8 @@
namespace sem {
class Call;
class Reference;
+class TypeConstructor;
+class TypeConversion;
} // namespace sem
namespace writer {
@@ -341,13 +343,6 @@
/// @returns the ID of the expression or 0 on failure.
uint32_t GenerateConstructorExpression(const ast::Variable* var,
const ast::Expression* expr);
- /// Generates a type constructor expression
- /// @param var the variable generated for, nullptr if no variable associated.
- /// @param init the expression to generate
- /// @returns the ID of the expression or 0 on failure.
- uint32_t GenerateTypeConstructorExpression(
- const ast::Variable* var,
- const ast::TypeConstructorExpression* init);
/// Generates a literal constant if needed
/// @param var the variable generated for, nullptr if no variable associated.
/// @param lit the literal to generate
@@ -371,12 +366,24 @@
/// @param expr the expression to generate
/// @returns the expression ID on success or 0 otherwise
uint32_t GenerateCallExpression(const ast::CallExpression* expr);
- /// Generates an intrinsic call
+ /// Handles generating a function call expression
/// @param call the call expression
- /// @param intrinsic the semantic information for the intrinsic
+ /// @param function the function being called
/// @returns the expression ID on success or 0 otherwise
- uint32_t GenerateIntrinsic(const ast::CallExpression* call,
- const sem::Intrinsic* intrinsic);
+ uint32_t GenerateFunctionCall(const sem::Call* call,
+ const sem::Function* function);
+ /// Handles generating an intrinsic call expression
+ /// @param call the call expression
+ /// @param intrinsic the intrinsic being called
+ /// @returns the expression ID on success or 0 otherwise
+ uint32_t GenerateIntrinsicCall(const sem::Call* call,
+ const sem::Intrinsic* intrinsic);
+ /// Handles generating a type constructor or type conversion expression
+ /// @param call the call expression
+ /// @param var the variable that is being initialized. May be null.
+ /// @returns the expression ID on success or 0 otherwise
+ uint32_t GenerateTypeConstructorOrConversion(const sem::Call* call,
+ const ast::Variable* var);
/// Generates a texture intrinsic call. Emits an error and returns false if
/// we're currently outside a function.
/// @param call the call expression
@@ -385,7 +392,7 @@
/// @param result_id result identifier operand of the texture instruction
/// parameters
/// @returns true on success
- bool GenerateTextureIntrinsic(const ast::CallExpression* call,
+ bool GenerateTextureIntrinsic(const sem::Call* call,
const sem::Intrinsic* intrinsic,
spirv::Operand result_type,
spirv::Operand result_id);
@@ -399,7 +406,7 @@
/// @param result_type result type operand of the texture instruction
/// @param result_id result identifier operand of the texture instruction
/// @returns true on success
- bool GenerateAtomicIntrinsic(const ast::CallExpression* call,
+ bool GenerateAtomicIntrinsic(const sem::Call* call,
const sem::Intrinsic* intrinsic,
Operand result_type,
Operand result_id);
@@ -536,9 +543,8 @@
/// Determines if the given type constructor is created from constant values
/// @param expr the expression to check
- /// @param is_global_init if this is a global initializer
/// @returns true if the constructor is constant
- bool is_constructor_const(const ast::Expression* expr, bool is_global_init);
+ bool IsConstructorConst(const ast::Expression* expr);
private:
/// @returns an Operand with a new result ID in it. Increments the next_id_
diff --git a/src/writer/spirv/builder_constructor_expression_test.cc b/src/writer/spirv/builder_constructor_expression_test.cc
index 3488b2b..709cea4 100644
--- a/src/writer/spirv/builder_constructor_expression_test.cc
+++ b/src/writer/spirv/builder_constructor_expression_test.cc
@@ -165,20 +165,6 @@
)");
}
-TEST_F(SpvBuilderConstructorTest, Type_NonConst_Value_Fails) {
- auto* rel = create<ast::BinaryExpression>(ast::BinaryOp::kAdd, Expr(3.0f),
- Expr(3.0f));
-
- auto* t = vec2<f32>(1.0f, rel);
- auto* g = Global("g", ty.vec2<f32>(), t, ast::StorageClass::kPrivate);
-
- spirv::Builder& b = Build();
-
- EXPECT_EQ(b.GenerateConstructorExpression(g, t), 0u);
- EXPECT_TRUE(b.has_error());
- EXPECT_EQ(b.error(), R"(constructor must be a constant expression)");
-}
-
TEST_F(SpvBuilderConstructorTest, Type_Bool_With_Bool) {
auto* cast = Construct<bool>(true);
WrapInFunction(cast);
@@ -668,6 +654,36 @@
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"()");
}
+TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_F32_With_F32) {
+ auto* ctor = Construct<f32>(2.0f);
+ GlobalConst("g", ty.f32(), ctor);
+
+ spirv::Builder& b = SanitizeAndBuild();
+ ASSERT_TRUE(b.Build());
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
+%2 = OpConstant %1 2
+%4 = OpTypeVoid
+%3 = OpTypeFunction %4
+)");
+ Validate(b);
+}
+
+TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_U32_With_F32) {
+ auto* ctor = Construct<u32>(1.5f);
+ GlobalConst("g", ty.u32(), ctor);
+
+ spirv::Builder& b = SanitizeAndBuild();
+ ASSERT_TRUE(b.Build());
+
+ EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
+%2 = OpConstant %1 1
+%4 = OpTypeVoid
+%3 = OpTypeFunction %4
+)");
+ Validate(b);
+}
+
TEST_F(SpvBuilderConstructorTest, Type_ModuleScope_Vec2_With_F32) {
auto* cast = vec2<f32>(2.0f);
auto* g = Global("g", ty.vec2<f32>(), cast, ast::StorageClass::kPrivate);
@@ -1689,27 +1705,10 @@
spirv::Builder& b = Build();
- EXPECT_TRUE(b.is_constructor_const(t, true));
+ EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
-TEST_F(SpvBuilderConstructorTest, IsConstructorConst_GlobalVector_WithIdent) {
- // vec3<f32>(a, b, c) -> false -- ERROR
-
- Global("a", ty.f32(), ast::StorageClass::kPrivate);
- Global("b", ty.f32(), ast::StorageClass::kPrivate);
- Global("c", ty.f32(), ast::StorageClass::kPrivate);
-
- auto* t = vec3<f32>("a", "b", "c");
- WrapInFunction(t);
-
- spirv::Builder& b = Build();
-
- EXPECT_FALSE(b.is_constructor_const(t, true));
- EXPECT_TRUE(b.has_error());
- EXPECT_EQ(b.error(), "constructor must be a constant expression");
-}
-
TEST_F(SpvBuilderConstructorTest,
IsConstructorConst_GlobalArrayWithAllConstConstructors) {
// array<vec3<f32>, 2>(vec3<f32>(1.0, 2.0, 3.0), vec3<f32>(1.0, 2.0, 3.0))
@@ -1720,7 +1719,7 @@
spirv::Builder& b = Build();
- EXPECT_TRUE(b.is_constructor_const(t, true));
+ EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
@@ -1733,12 +1732,12 @@
spirv::Builder& b = Build();
- EXPECT_FALSE(b.is_constructor_const(t, true));
+ EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
TEST_F(SpvBuilderConstructorTest,
- IsConstructorConst_GlobalWithTypeCastConstructor) {
+ IsConstructorConst_GlobalWithTypeConversionConstructor) {
// vec2<f32>(f32(1), f32(2)) -> false
auto* t = vec2<f32>(Construct<f32>(1), Construct<f32>(2));
@@ -1746,7 +1745,7 @@
spirv::Builder& b = Build();
- EXPECT_FALSE(b.is_constructor_const(t, true));
+ EXPECT_FALSE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
@@ -1759,7 +1758,7 @@
spirv::Builder& b = Build();
- EXPECT_TRUE(b.is_constructor_const(t, false));
+ EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
@@ -1775,7 +1774,7 @@
spirv::Builder& b = Build();
- EXPECT_FALSE(b.is_constructor_const(t, false));
+ EXPECT_FALSE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
@@ -1792,12 +1791,12 @@
spirv::Builder& b = Build();
- EXPECT_TRUE(b.is_constructor_const(t, false));
+ EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
TEST_F(SpvBuilderConstructorTest,
- IsConstructorConst_VectorWithTypeCastConstConstructors) {
+ IsConstructorConst_VectorWithTypeConversionConstConstructors) {
// vec2<f32>(f32(1), f32(2)) -> false
auto* t = vec2<f32>(Construct<f32>(1), Construct<f32>(2));
@@ -1805,7 +1804,7 @@
spirv::Builder& b = Build();
- EXPECT_FALSE(b.is_constructor_const(t, false));
+ EXPECT_FALSE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
@@ -1815,7 +1814,7 @@
spirv::Builder& b = Build();
- EXPECT_FALSE(b.is_constructor_const(t, false));
+ EXPECT_FALSE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
@@ -1830,7 +1829,7 @@
spirv::Builder& b = Build();
- EXPECT_TRUE(b.is_constructor_const(t, false));
+ EXPECT_TRUE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
@@ -1849,7 +1848,7 @@
spirv::Builder& b = Build();
- EXPECT_FALSE(b.is_constructor_const(t, false));
+ EXPECT_FALSE(b.IsConstructorConst(t));
EXPECT_FALSE(b.has_error());
}
diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc
index b1fe63b..29ea3ef 100644
--- a/src/writer/wgsl/generator_impl.cc
+++ b/src/writer/wgsl/generator_impl.cc
@@ -134,9 +134,6 @@
if (auto* l = expr->As<ast::LiteralExpression>()) {
return EmitLiteral(out, l);
}
- if (auto* c = expr->As<ast::TypeConstructorExpression>()) {
- return EmitTypeConstructor(out, c);
- }
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
return EmitMemberAccessor(out, m);
}
@@ -156,10 +153,9 @@
std::ostream& out,
const ast::IndexAccessorExpression* expr) {
bool paren_lhs =
- !expr->object
- ->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
- ast::IdentifierExpression, ast::MemberAccessorExpression,
- ast::TypeConstructorExpression>();
+ !expr->object->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
+ ast::IdentifierExpression,
+ ast::MemberAccessorExpression>();
if (paren_lhs) {
out << "(";
}
@@ -183,10 +179,9 @@
std::ostream& out,
const ast::MemberAccessorExpression* expr) {
bool paren_lhs =
- !expr->structure
- ->IsAnyOf<ast::IndexAccessorExpression, ast::CallExpression,
- ast::IdentifierExpression, ast::MemberAccessorExpression,
- ast::TypeConstructorExpression>();
+ !expr->structure->IsAnyOf<ast::IndexAccessorExpression,
+ ast::CallExpression, ast::IdentifierExpression,
+ ast::MemberAccessorExpression>();
if (paren_lhs) {
out << "(";
}
@@ -220,7 +215,17 @@
bool GeneratorImpl::EmitCall(std::ostream& out,
const ast::CallExpression* expr) {
- if (!EmitExpression(out, expr->func)) {
+ if (expr->target.name) {
+ if (!EmitExpression(out, expr->target.name)) {
+ return false;
+ }
+ } else if (expr->target.type) {
+ if (!EmitType(out, expr->target.type)) {
+ return false;
+ }
+ } else {
+ TINT_ICE(Writer, diagnostics_)
+ << "CallExpression target had neither a name or type";
return false;
}
out << "(";
@@ -243,31 +248,6 @@
return true;
}
-bool GeneratorImpl::EmitTypeConstructor(
- std::ostream& out,
- const ast::TypeConstructorExpression* expr) {
- if (!EmitType(out, expr->type)) {
- return false;
- }
-
- out << "(";
-
- bool first = true;
- for (auto* e : expr->values) {
- if (!first) {
- out << ", ";
- }
- first = false;
-
- if (!EmitExpression(out, e)) {
- return false;
- }
- }
-
- out << ")";
- return true;
-}
-
bool GeneratorImpl::EmitLiteral(std::ostream& out,
const ast::LiteralExpression* lit) {
if (auto* bl = lit->As<ast::BoolLiteralExpression>()) {
diff --git a/src/writer/wgsl/generator_impl.h b/src/writer/wgsl/generator_impl.h
index 7f4a68c..a88f263 100644
--- a/src/writer/wgsl/generator_impl.h
+++ b/src/writer/wgsl/generator_impl.h
@@ -31,7 +31,6 @@
#include "src/ast/member_accessor_expression.h"
#include "src/ast/return_statement.h"
#include "src/ast/switch_statement.h"
-#include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_op_expression.h"
#include "src/program.h"
#include "src/sem/storage_texture_type.h"
@@ -183,12 +182,6 @@
/// @param access the access to generate
/// @returns true if the access is emitted
bool EmitAccess(std::ostream& out, const ast::Access access);
- /// Handles emitting a type constructor
- /// @param out the output of the expression stream
- /// @param expr the type constructor expression
- /// @returns true if the constructor is emitted
- bool EmitTypeConstructor(std::ostream& out,
- const ast::TypeConstructorExpression* expr);
/// Handles a unary op expression
/// @param out the output of the expression stream
/// @param expr the expression to emit
diff --git a/test/BUILD.gn b/test/BUILD.gn
index 6325c34..2a1883c 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -203,7 +203,6 @@
"../src/ast/test_helper.h",
"../src/ast/texture_test.cc",
"../src/ast/traverse_expressions_test.cc",
- "../src/ast/type_constructor_expression_test.cc",
"../src/ast/u32_test.cc",
"../src/ast/uint_literal_expression_test.cc",
"../src/ast/unary_op_expression_test.cc",