// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <utility>

#include "gmock/gmock.h"
#include "src/tint/ir/builder.h"
#include "src/tint/ir/ir_test_helper.h"
#include "src/tint/ir/validate.h"
#include "src/tint/type/matrix.h"
#include "src/tint/type/pointer.h"
#include "src/tint/type/struct.h"

namespace tint::ir {
namespace {

// Brings the numeric literal suffixes used by the tests below (e.g. `1_u`, `-1_i`) into scope.
using namespace tint::number_suffixes;  // NOLINT

// Fixture for the IR validator tests. The `mod` (module), `b` (IR builder) and `ty` (type helper)
// names used inside each TEST_F appear to be provided by IRTestHelper.
using IR_ValidateTest = IRTestHelper;

// A root block whose only instruction is a private i32 `var` passes validation.
TEST_F(IR_ValidateTest, RootBlock_Var) {
    auto* root = b.RootBlock();
    mod.root_block = root;

    auto* var_type =
        ty.ptr(builtin::AddressSpace::kPrivate, ty.i32(), builtin::Access::kReadWrite);
    root->Append(b.Var(var_type));

    auto result = ir::Validate(mod);
    EXPECT_TRUE(result) << result.Failure().str();
}

// Placing a loop in the root block is rejected: the diagnostic names the offending instruction
// and a note points at the enclosing root block.
TEST_F(IR_ValidateTest, RootBlock_NonVar) {
    auto* loop = b.Loop();
    auto* loop_body = loop->Body();
    loop_body->Append(b.Continue(loop));

    auto* root = b.RootBlock();
    mod.root_block = root;
    root->Append(loop);

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected = R"(:3:3 error: root block: invalid instruction: tint::ir::Loop
  loop [b: %b2]
  ^^^^^^^^^^^^^

:2:1 note: In block
%b1 = block {
^^^^^^^^^^^

note: # Disassembly
# Root block
%b1 = block {
  loop [b: %b2]
    # Body block
    %b2 = block {
      continue %b3
    }

}

)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// A void function taking (i32, f32) whose entry block is just `ret` passes validation.
TEST_F(IR_ValidateTest, Function) {
    auto* func = b.Function("my_func", ty.void_());
    mod.functions.Push(func);

    auto* int_param = b.FunctionParam(ty.i32());
    auto* float_param = b.FunctionParam(ty.f32());
    func->SetParams({int_param, float_param});

    auto* entry = func->StartTarget();
    entry->SetInstructions({b.Return(func)});

    auto result = ir::Validate(mod);
    EXPECT_TRUE(result) << result.Failure().str();
}

// A function whose entry block is left empty has no terminating branch, so validation fails.
TEST_F(IR_ValidateTest, Block_NoBranchAtEnd) {
    mod.functions.Push(b.Function("my_func", ty.void_()));

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected = R"(:2:3 error: block: does not end in a branch
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
%my_func = func():void -> %b1 {
  %b1 = block {
  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// Accessing element [1][0] of a mat3x2<f32> value as an f32 passes validation.
TEST_F(IR_ValidateTest, Valid_Access_Value) {
    auto* func = b.Function("my_func", ty.void_());
    auto* mat_param = b.FunctionParam(ty.mat3x2<f32>());
    func->SetParams({mat_param});
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->Append(b.Access(ty.f32(), mat_param, 1_u, 0_u));
    entry->Append(b.Return(func));

    auto result = ir::Validate(mod);
    EXPECT_TRUE(result) << result.Failure().str();
}

// Accessing element [1][0] through a pointer to a private mat3x2<f32>, with a pointer-to-f32
// result type, passes validation.
TEST_F(IR_ValidateTest, Valid_Access_Ptr) {
    auto* func = b.Function("my_func", ty.void_());
    auto* mat_ptr = b.FunctionParam(
        ty.ptr(builtin::AddressSpace::kPrivate, ty.mat3x2<f32>(), builtin::Access::kReadWrite));
    func->SetParams({mat_ptr});
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->Append(b.Access(ty.ptr<private_, f32>(), mat_ptr, 1_u, 0_u));
    entry->Append(b.Return(func));

    auto result = ir::Validate(mod);
    EXPECT_TRUE(result) << result.Failure().str();
}

// A constant index of -1 into a vec3<f32> is rejected, with the diagnostic on the index operand.
TEST_F(IR_ValidateTest, Access_NegativeIndex) {
    auto* func = b.Function("my_func", ty.void_());
    auto* vec_param = b.FunctionParam(ty.vec3<f32>());
    func->SetParams({vec_param});
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->Append(b.Access(ty.f32(), vec_param, -1_i));
    entry->Append(b.Return(func));

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected = R"(:3:25 error: access: constant index must be positive, got -1
    %3:f32 = access %2, -1i
                        ^^^

:2:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
%my_func = func(%2:vec3<f32>):void -> %b1 {
  %b1 = block {
    %3:f32 = access %2, -1i
    ret
  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// Column 1 of a mat3x2<f32> value is a vec2<f32>, so a second index of 3 is out of bounds. The
// error points at that index and a note reports the acceptable range.
TEST_F(IR_ValidateTest, Access_OOB_Index_Value) {
    auto* func = b.Function("my_func", ty.void_());
    auto* mat_param = b.FunctionParam(ty.mat3x2<f32>());
    func->SetParams({mat_param});
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->Append(b.Access(ty.f32(), mat_param, 1_u, 3_u));
    entry->Append(b.Return(func));

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected = R"(:3:29 error: access: index out of bounds for type vec2<f32>
    %3:f32 = access %2, 1u, 3u
                            ^^

:2:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

:3:29 note: acceptable range: [0..1]
    %3:f32 = access %2, 1u, 3u
                            ^^

note: # Disassembly
%my_func = func(%2:mat3x2<f32>):void -> %b1 {
  %b1 = block {
    %3:f32 = access %2, 1u, 3u
    ret
  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// Same as Access_OOB_Index_Value, but the matrix is reached through a private pointer, so the
// diagnostic names the pointer type ptr<vec2<f32>>.
TEST_F(IR_ValidateTest, Access_OOB_Index_Ptr) {
    auto* func = b.Function("my_func", ty.void_());
    auto* mat_ptr = b.FunctionParam(
        ty.ptr(builtin::AddressSpace::kPrivate, ty.mat3x2<f32>(), builtin::Access::kReadWrite));
    func->SetParams({mat_ptr});
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->Append(b.Access(ty.ptr<private_, f32>(), mat_ptr, 1_u, 3_u));
    entry->Append(b.Return(func));

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected =
        R"(:3:55 error: access: index out of bounds for type ptr<vec2<f32>>
    %3:ptr<private, f32, read_write> = access %2, 1u, 3u
                                                      ^^

:2:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

:3:55 note: acceptable range: [0..1]
    %3:ptr<private, f32, read_write> = access %2, 1u, 3u
                                                      ^^

note: # Disassembly
%my_func = func(%2:ptr<private, mat3x2<f32>, read_write>):void -> %b1 {
  %b1 = block {
    %3:ptr<private, f32, read_write> = access %2, 1u, 3u
    ret
  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// A scalar f32 value cannot be indexed at all, even with a constant index.
TEST_F(IR_ValidateTest, Access_StaticallyUnindexableType_Value) {
    auto* func = b.Function("my_func", ty.void_());
    auto* scalar_param = b.FunctionParam(ty.f32());
    func->SetParams({scalar_param});
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->Append(b.Access(ty.f32(), scalar_param, 1_u));
    entry->Append(b.Return(func));

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected = R"(:3:25 error: access: type f32 cannot be indexed
    %3:f32 = access %2, 1u
                        ^^

:2:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
%my_func = func(%2:f32):void -> %b1 {
  %b1 = block {
    %3:f32 = access %2, 1u
    ret
  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// A pointer to a scalar f32 cannot be indexed either; the diagnostic names ptr<f32>.
TEST_F(IR_ValidateTest, Access_StaticallyUnindexableType_Ptr) {
    auto* func = b.Function("my_func", ty.void_());
    auto* scalar_ptr = b.FunctionParam(ty.ptr<private_, f32>());
    func->SetParams({scalar_ptr});
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->Append(b.Access(ty.ptr<private_, f32>(), scalar_ptr, 1_u));
    entry->Append(b.Return(func));

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected = R"(:3:51 error: access: type ptr<f32> cannot be indexed
    %3:ptr<private, f32, read_write> = access %2, 1u
                                                  ^^

:2:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
%my_func = func(%2:ptr<private, f32, read_write>):void -> %b1 {
  %b1 = block {
    %3:ptr<private, f32, read_write> = access %2, 1u
    ret
  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// A struct value cannot be indexed with a non-constant index (here the i32 function parameter
// %3). The validator rejects the access with a diagnostic on the index operand.
TEST_F(IR_ValidateTest, Access_DynamicallyUnindexableType_Value) {
    // Two i32 members; the expected disassembly reports them at @offset(0) and @offset(4).
    // NOTE: the generated names in the expected output (tint_symbol and tint_symbol_1 for the
    // members, tint_symbol_2 for the struct) appear to follow the order of the
    // `mod.symbols.New()` calls below, so keep that order if this test is edited.
    utils::Vector members{
        ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 0u, 0u, 4u, 4u,
                                   type::StructMemberAttributes{}),
        ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 1u, 4u, 4u, 4u,
                                   type::StructMemberAttributes{}),
    };
    auto* str_ty = ty.Get<type::Struct>(mod.symbols.New(), std::move(members), 4u, 8u, 8u);

    // my_func(%2: struct, %3: i32) performs `access %2, %3`, i.e. a runtime index into a struct.
    auto* f = b.Function("my_func", ty.void_());
    auto* obj = b.FunctionParam(str_ty);
    auto* idx = b.FunctionParam(ty.i32());
    f->SetParams({obj, idx});
    mod.functions.Push(f);

    f->StartTarget()->Append(b.Access(ty.i32(), obj, idx));
    f->StartTarget()->Append(b.Return(f));

    auto res = ir::Validate(mod);
    ASSERT_FALSE(res);
    EXPECT_EQ(res.Failure().str(),
              R"(:8:25 error: access: type tint_symbol_2 cannot be dynamically indexed
    %4:i32 = access %2, %3
                        ^^

:7:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
tint_symbol_2 = struct @align(4) {
  tint_symbol:i32 @offset(0)
  tint_symbol_1:i32 @offset(4)
}

%my_func = func(%2:tint_symbol_2, %3:i32):void -> %b1 {
  %b1 = block {
    %4:i32 = access %2, %3
    ret
  }
}
)");
}

// Dynamically indexing through a pointer to a struct must be rejected.
//
// The access result type is a pointer (ptr<private, i32>), as in the other *_Ptr access tests.
// This keeps the dynamic-indexing diagnostic as the only problem in the module. With a value
// result type the instruction would also carry a result-type mismatch (see
// Access_Incorrect_Type_Ptr_Value), and the test would silently depend on the validator stopping
// before it reports that mismatch.
TEST_F(IR_ValidateTest, Access_DynamicallyUnindexableType_Ptr) {
    // The order of the `mod.symbols.New()` calls appears to determine the generated names
    // (tint_symbol, tint_symbol_1, tint_symbol_2) that the expected disassembly relies on.
    utils::Vector members{
        ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 0u, 0u, 4u, 4u,
                                   type::StructMemberAttributes{}),
        ty.Get<type::StructMember>(mod.symbols.New(), ty.i32(), 1u, 4u, 4u, 4u,
                                   type::StructMemberAttributes{}),
    };
    auto* str_ty = ty.Get<type::Struct>(mod.symbols.New(), std::move(members), 4u, 8u, 8u);

    auto* f = b.Function("my_func", ty.void_());
    auto* obj = b.FunctionParam(
        ty.ptr(builtin::AddressSpace::kPrivate, str_ty, builtin::Access::kReadWrite));
    auto* idx = b.FunctionParam(ty.i32());
    f->SetParams({obj, idx});
    mod.functions.Push(f);

    // Indexing through a pointer yields a pointer, so use a pointer result type.
    f->StartTarget()->Append(b.Access(ty.ptr<private_, i32>(), obj, idx));
    f->StartTarget()->Append(b.Return(f));

    auto res = ir::Validate(mod);
    ASSERT_FALSE(res);
    EXPECT_EQ(res.Failure().str(),
              R"(:8:51 error: access: type ptr<tint_symbol_2> cannot be dynamically indexed
    %4:ptr<private, i32, read_write> = access %2, %3
                                                  ^^

:7:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
tint_symbol_2 = struct @align(4) {
  tint_symbol:i32 @offset(0)
  tint_symbol_1:i32 @offset(4)
}

%my_func = func(%2:ptr<private, tint_symbol_2, read_write>, %3:i32):void -> %b1 {
  %b1 = block {
    %4:ptr<private, i32, read_write> = access %2, %3
    ret
  }
}
)");
}

// Element [1][1] of a mat3x2<f32> value is an f32, so declaring the access result as i32 is a
// type mismatch. The diagnostic is reported on the `access` instruction itself.
TEST_F(IR_ValidateTest, Access_Incorrect_Type_Value_Value) {
    auto* func = b.Function("my_func", ty.void_());
    auto* mat_param = b.FunctionParam(ty.mat3x2<f32>());
    func->SetParams({mat_param});
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->Append(b.Access(ty.i32(), mat_param, 1_u, 1_u));
    entry->Append(b.Return(func));

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected =
        R"(:3:14 error: access: result of access chain is type f32 but instruction type is i32
    %3:i32 = access %2, 1u, 1u
             ^^^^^^

:2:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
%my_func = func(%2:mat3x2<f32>):void -> %b1 {
  %b1 = block {
    %3:i32 = access %2, 1u, 1u
    ret
  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// Indexing [1][1] through a pointer to mat3x2<f32> yields ptr<f32>. A declared result type of
// ptr<i32> therefore does not match and is rejected.
TEST_F(IR_ValidateTest, Access_Incorrect_Type_Ptr_Ptr) {
    auto* func = b.Function("my_func", ty.void_());
    auto* mat_ptr = b.FunctionParam(
        ty.ptr(builtin::AddressSpace::kPrivate, ty.mat3x2<f32>(), builtin::Access::kReadWrite));
    func->SetParams({mat_ptr});
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->Append(b.Access(ty.ptr<private_, i32>(), mat_ptr, 1_u, 1_u));
    entry->Append(b.Return(func));

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected =
        R"(:3:40 error: access: result of access chain is type ptr<f32> but instruction type is ptr<i32>
    %3:ptr<private, i32, read_write> = access %2, 1u, 1u
                                       ^^^^^^

:2:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
%my_func = func(%2:ptr<private, mat3x2<f32>, read_write>):void -> %b1 {
  %b1 = block {
    %3:ptr<private, i32, read_write> = access %2, 1u, 1u
    ret
  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// Indexing through a pointer yields a pointer (ptr<f32>), so declaring the result as a plain f32
// value is rejected.
TEST_F(IR_ValidateTest, Access_Incorrect_Type_Ptr_Value) {
    auto* func = b.Function("my_func", ty.void_());
    auto* mat_ptr = b.FunctionParam(
        ty.ptr(builtin::AddressSpace::kPrivate, ty.mat3x2<f32>(), builtin::Access::kReadWrite));
    func->SetParams({mat_ptr});
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->Append(b.Access(ty.f32(), mat_ptr, 1_u, 1_u));
    entry->Append(b.Return(func));

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected =
        R"(:3:14 error: access: result of access chain is type ptr<f32> but instruction type is f32
    %3:f32 = access %2, 1u, 1u
             ^^^^^^

:2:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
%my_func = func(%2:ptr<private, mat3x2<f32>, read_write>):void -> %b1 {
  %b1 = block {
    %3:f32 = access %2, 1u, 1u
    ret
  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// Two consecutive `ret`s: the first is a branch that is not the block's final instruction, and
// the diagnostic points at it (line 3).
TEST_F(IR_ValidateTest, Block_BranchInMiddle) {
    auto* func = b.Function("my_func", ty.void_());
    mod.functions.Push(func);

    auto* entry = func->StartTarget();
    entry->SetInstructions({b.Return(func), b.Return(func)});

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected = R"(:3:5 error: block: branch which isn't the final instruction
    ret
    ^^^

:2:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
%my_func = func():void -> %b1 {
  %b1 = block {
    ret
    ret
  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

// Enforces that an `if` condition must be a bool. The test name describes the rule being
// checked: the input deliberately uses the i32 constant 1 as the condition, and the validator is
// expected to reject it.
TEST_F(IR_ValidateTest, If_ConditionIsBool) {
    auto* func = b.Function("my_func", ty.void_());
    mod.functions.Push(func);

    auto* if_inst = b.If(1_i);
    auto* true_block = if_inst->True();
    auto* false_block = if_inst->False();
    true_block->Append(b.Return(func));
    false_block->Append(b.Return(func));

    auto* entry = func->StartTarget();
    entry->Append(if_inst);

    auto result = ir::Validate(mod);
    ASSERT_FALSE(result);

    const char* expected = R"(:3:8 error: if: condition must be a `bool` type
    if 1i [t: %b2, f: %b3]
       ^^

:2:3 note: In block
  %b1 = block {
  ^^^^^^^^^^^

note: # Disassembly
%my_func = func():void -> %b1 {
  %b1 = block {
    if 1i [t: %b2, f: %b3]
      # True block
      %b2 = block {
        ret
      }

      # False block
      %b3 = block {
        ret
      }

  }
}
)";
    EXPECT_EQ(result.Failure().str(), expected);
}

}  // namespace
}  // namespace tint::ir
