// blob: ea3da7c553b17f4ff9b5a623e7ca606b097f6907 [file] [log] [blame]
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/writer/append_vector.h"
#include <utility>
#include "src/sem/expression.h"
namespace tint {
namespace writer {
namespace {
/// Returns `expr` as an ast::TypeConstructorExpression when it is a type
/// constructor whose resolved type (as reported by `b->TypeOf()`) is a
/// sem::Vector. Returns nullptr for every other expression.
/// @param b the program builder holding the semantic information
/// @param expr the expression to inspect
/// @returns the vector constructor, or nullptr
ast::TypeConstructorExpression* AsVectorConstructor(ProgramBuilder* b,
                                                    ast::Expression* expr) {
  auto* ctor = expr->As<ast::TypeConstructorExpression>();
  // Only query the type when `expr` actually is a type constructor.
  const bool is_vector_ctor =
      ctor != nullptr && b->TypeOf(ctor)->Is<sem::Vector>();
  return is_vector_ctor ? ctor : nullptr;
}
} // namespace
namespace {

/// Creates the AST type node naming the scalar semantic type `el_ty`.
/// @param b the program builder that owns the new node
/// @param el_ty the semantic element type (i32, u32, f32 or bool)
/// @returns the new AST type, or nullptr after reporting an internal compiler
/// error when `el_ty` is not one of the supported scalar types
ast::Type* CreateASTElementType(ProgramBuilder* b, const sem::Type* el_ty) {
  if (el_ty->Is<sem::I32>()) {
    return b->create<ast::I32>();
  }
  if (el_ty->Is<sem::U32>()) {
    return b->create<ast::U32>();
  }
  if (el_ty->Is<sem::F32>()) {
    return b->create<ast::F32>();
  }
  if (el_ty->Is<sem::Bool>()) {
    return b->create<ast::Bool>();
  }
  TINT_UNREACHABLE(Writer, b->Diagnostics())
      << "unsupported vector element type: " << el_ty->TypeInfo().name;
  return nullptr;
}

/// Creates a zero-valued scalar literal of the semantic type `el_ty`.
/// @param b the program builder that owns the new node
/// @param el_ty the semantic element type (i32, u32, f32 or bool)
/// @returns the new literal, or nullptr after reporting an internal compiler
/// error when `el_ty` is not one of the supported scalar types
ast::ScalarConstructorExpression* CreateZeroElement(ProgramBuilder* b,
                                                    const sem::Type* el_ty) {
  if (el_ty->Is<sem::I32>()) {
    return b->Expr(0);
  }
  if (el_ty->Is<sem::U32>()) {
    return b->Expr(0u);
  }
  if (el_ty->Is<sem::F32>()) {
    return b->Expr(0.0f);
  }
  if (el_ty->Is<sem::Bool>()) {
    return b->Expr(false);
  }
  TINT_UNREACHABLE(Writer, b->Diagnostics())
      << "unsupported vector element type: " << el_ty->TypeInfo().name;
  return nullptr;
}

}  // namespace

/// Builds a vector constructor expression whose components are those of
/// `vector` followed by `scalar`.
/// If `vector` is itself a scalar, the result is a 2-element vector of that
/// scalar's type. `scalar` is wrapped in a conversion to the element type
/// when its (reference-unwrapped) type differs. Semantic information is
/// registered on `b` for every expression node created here.
/// @param b the program builder
/// @param vector the vector (or scalar) expression to extend
/// @param scalar the scalar expression to append
/// @returns the new vector constructor expression
ast::TypeConstructorExpression* AppendVector(ProgramBuilder* b,
                                             ast::Expression* vector,
                                             ast::Expression* scalar) {
  auto* vector_sem = b->Sem().Get(vector);
  auto* vector_ty = vector_sem->Type()->UnwrapRef();

  // A vector operand grows by one component. A scalar operand becomes the
  // first element of a 2-element vector of its own type.
  uint32_t packed_size = 2;
  const sem::Type* packed_el_sem_ty = vector_ty;
  if (auto* vec = vector_ty->As<sem::Vector>()) {
    packed_size = vec->Width() + 1;
    packed_el_sem_ty = vec->type();
  }

  ast::Type* packed_el_ty = CreateASTElementType(b, packed_el_sem_ty);

  auto* statement = vector_sem->Stmt();
  auto* packed_ty = b->create<ast::Vector>(packed_el_ty, packed_size);
  auto* packed_sem_ty = b->create<sem::Vector>(packed_el_sem_ty, packed_size);

  // If `vector` is already a vector constructor that supplies only scalar
  // components, reuse those components directly instead of nesting a
  // vector-in-vector.
  // If `vector` is a zero-value vector constructor, expand it into explicit
  // scalar zeros.
  // Any other nested vector constructor is used to convert a vector of a
  // different type, e.g. vec2<i32>(vec2<u32>()). That case keeps `vector` as a
  // single argument, because unpacking it would cause a type error.
  ast::ExpressionList packed;
  if (auto* vc = AsVectorConstructor(b, vector)) {
    const auto num_supplied = vc->values().size();
    if (num_supplied == 0) {
      // Zero-value vector constructor: emit one scalar zero per component.
      for (uint32_t i = 0; i < packed_size - 1; i++) {
        auto* zero = CreateZeroElement(b, packed_el_sem_ty);
        b->Sem().Add(
            zero, b->create<sem::Expression>(zero, packed_el_sem_ty, statement,
                                             sem::Constant{}));
        packed.emplace_back(zero);
      }
    } else if (num_supplied + 1 == packed_size) {
      // All vector components were supplied as scalars. Pass them through.
      packed = vc->values();
    }
  }
  if (packed.empty()) {
    // None of the special cases applied: use `vector` as a single argument.
    packed.emplace_back(vector);
  }

  if (packed_el_sem_ty != b->TypeOf(scalar)->UnwrapRef()) {
    // Convert `scalar` to the vector element type.
    auto* scalar_cast = b->Construct(packed_el_ty, scalar);
    b->Sem().Add(scalar_cast,
                 b->create<sem::Expression>(scalar_cast, packed_el_sem_ty,
                                            statement, sem::Constant{}));
    packed.emplace_back(scalar_cast);
  } else {
    packed.emplace_back(scalar);
  }

  auto* constructor = b->Construct(packed_ty, std::move(packed));
  b->Sem().Add(constructor,
               b->create<sem::Expression>(constructor, packed_sem_ty, statement,
                                          sem::Constant{}));
  return constructor;
}
} // namespace writer
} // namespace tint