// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// test-runner runs tint against a number of test shaders checking for expected behavior
package main

import (
	"context"
	"flag"
	"fmt"
	"io/ioutil"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"sort"
	"strings"
	"time"
	"unicode/utf8"

	"dawn.googlesource.com/tint/tools/src/fileutils"
	"dawn.googlesource.com/tint/tools/src/glob"
	"github.com/fatih/color"
	"github.com/sergi/go-diff/diffmatchpatch"
)

// outputFormat names one of the output formats that tint is asked to emit.
type outputFormat string

// testTimeout bounds how long a single invocation of the tint executable may
// run before it is cancelled.
const testTimeout = 30 * time.Second

// The output formats the test runner knows about.
const (
	wgsl   outputFormat = "wgsl"
	spvasm outputFormat = "spvasm"
	msl    outputFormat = "msl"
	hlsl   outputFormat = "hlsl"
)

// main is the program entry point. It delegates to run and, if run reports an
// error, prints it and exits with a non-zero status.
func main() {
	err := run()
	if err == nil {
		return
	}
	fmt.Println(err)
	os.Exit(1)
}

// showUsage prints the command line help, followed by the defaults of every
// registered flag, and then terminates the process with exit code 1.
func showUsage() {
	const usage = `
test-runner runs tint against a number of test shaders checking for expected behavior

usage:
  test-runner [flags...] <executable> [<directory>]

  <executable> the path to the tint executable
  <directory>  the root directory of the test files

optional flags:`

	fmt.Println(usage)
	flag.PrintDefaults()
	fmt.Println()
	os.Exit(1)
}

// run parses the command line, globs the test files below the target
// directory, runs the tint executable on every (file, format) pair using a
// pool of worker goroutines, and prints a table of results followed by
// per-format statistics, the details of every failure and an overall summary.
//
// An error is returned if the executable or an explicitly requested validation
// tool cannot be used, if an unknown format is requested, or if globbing
// fails. If any test fails, the process exits with status 1 once the summary
// has been printed.
func run() error {
	var formatList, filter, dxcPath, xcrunPath string
	var maxFilenameColumnWidth int
	numCPU := runtime.NumCPU()
	fxc, verbose, generateExpected, generateSkip := false, false, false, false
	flag.StringVar(&formatList, "format", "all", "comma separated list of formats to emit. Possible values are: all, wgsl, spvasm, msl, hlsl")
	flag.StringVar(&filter, "filter", "**.wgsl, **.spvasm, **.spv", "comma separated list of glob patterns for test files")
	flag.StringVar(&dxcPath, "dxc", "", "path to DXC executable for validating HLSL output")
	flag.StringVar(&xcrunPath, "xcrun", "", "path to xcrun executable for validating MSL output")
	flag.BoolVar(&fxc, "fxc", false, "validate with FXC instead of DXC")
	flag.BoolVar(&verbose, "verbose", false, "print all run tests, including rows that all pass")
	flag.BoolVar(&generateExpected, "generate-expected", false, "create or update all expected outputs")
	flag.BoolVar(&generateSkip, "generate-skip", false, "create or update all expected outputs that fail with SKIP")
	flag.IntVar(&numCPU, "j", numCPU, "maximum number of concurrent threads to run tests")
	flag.IntVar(&maxFilenameColumnWidth, "filename-column-width", 0, "maximum width of the filename column")
	flag.Usage = showUsage
	flag.Parse()

	// At least one worker is required: with none, nothing would ever drain the
	// job queue and the results loop below would block forever.
	if numCPU < 1 {
		numCPU = 1
	}

	args := flag.Args()
	if len(args) == 0 {
		showUsage()
	}

	// executable path is the first argument
	exe, args := args[0], args[1:]

	// (optional) target directory is the second argument
	dir := "."
	if len(args) > 0 {
		dir = args[0]
	}

	// Check the executable can be found and actually is executable
	if !fileutils.IsExe(exe) {
		return fmt.Errorf("'%s' not found or is not executable", exe)
	}
	exe, err := filepath.Abs(exe)
	if err != nil {
		return err
	}

	// Allow using '/' in the filter on Windows
	filter = strings.ReplaceAll(filter, "/", string(filepath.Separator))

	// Split the --filter flag up by ',', trimming any whitespace at the start and end
	globIncludes := strings.Split(filter, ",")
	for i, s := range globIncludes {
		// Escape backslashes for the glob config
		s = strings.ReplaceAll(s, `\`, `\\`)
		globIncludes[i] = `"` + strings.TrimSpace(s) + `"`
	}

	// Glob the files to test
	files, err := glob.Scan(dir, glob.MustParseConfig(`{
		"paths": [
			{
				"include": [ `+strings.Join(globIncludes, ",")+` ]
			},
			{
				"exclude": [
					"**.expected.wgsl",
					"**.expected.spvasm",
					"**.expected.msl",
					"**.expected.hlsl"
				]
			}
		]
	}`))
	if err != nil {
		return fmt.Errorf("Failed to glob files: %w", err)
	}

	// Ensure the files are sorted (globbing should do this, but why not)
	sort.Strings(files)

	// Parse --format into a list of outputFormat
	formats := []outputFormat{}
	if formatList == "all" {
		formats = []outputFormat{wgsl, spvasm, msl, hlsl}
	} else {
		for _, f := range strings.Split(formatList, ",") {
			switch strings.TrimSpace(f) {
			case "wgsl":
				formats = append(formats, wgsl)
			case "spvasm":
				formats = append(formats, spvasm)
			case "msl":
				formats = append(formats, msl)
			case "hlsl":
				formats = append(formats, hlsl)
			default:
				return fmt.Errorf("unknown format '%s'", f)
			}
		}
	}

	defaultMSLExe := "xcrun"
	if runtime.GOOS == "windows" {
		defaultMSLExe = "metal.exe"
	}

	// If explicit verification compilers have been specified, check they exist.
	// Otherwise, look on PATH for them, but don't error if they cannot be found.
	for _, tool := range []struct {
		name string
		lang string
		path *string
	}{
		{"dxc", "hlsl", &dxcPath},
		{defaultMSLExe, "msl", &xcrunPath},
	} {
		if *tool.path == "" {
			p, err := exec.LookPath(tool.name)
			if err == nil && fileutils.IsExe(p) {
				*tool.path = p
			}
		} else if !fileutils.IsExe(*tool.path) {
			return fmt.Errorf("%v not found at '%v'", tool.name, *tool.path)
		}

		color.Set(color.FgCyan)
		fmt.Printf("%-4s", tool.lang)
		color.Unset()
		fmt.Print(" validation ")
		if *tool.path != "" || (fxc && tool.lang == "hlsl") {
			color.Set(color.FgGreen)
			fmt.Print("ENABLED")
		} else {
			color.Set(color.FgRed)
			fmt.Print("DISABLED")
		}
		color.Unset()
		fmt.Println()
	}
	fmt.Println()

	// Build the list of results.
	// These hold the chans used to report the job results.
	results := make([]map[outputFormat]chan status, len(files))
	for i := range files {
		fileResults := map[outputFormat]chan status{}
		for _, format := range formats {
			fileResults[format] = make(chan status, 1)
		}
		results[i] = fileResults
	}

	pendingJobs := make(chan job, 256)

	// Spawn numCPU job runners...
	for cpu := 0; cpu < numCPU; cpu++ {
		go func() {
			for job := range pendingJobs {
				job.run(dir, exe, fxc, dxcPath, xcrunPath, generateExpected, generateSkip)
			}
		}()
	}

	// Issue the jobs...
	go func() {
		for i, file := range files { // For each test file...
			file := filepath.Join(dir, file)
			for _, format := range formats { // For each output format...
				pendingJobs <- job{
					file:   file,
					format: format,
					result: results[i][format],
				}
			}
		}
		close(pendingJobs)
	}()

	type failure struct {
		file   string
		format outputFormat
		err    error
	}

	type stats struct {
		numTests, numPass, numSkip, numFail int
	}

	// Statistics per output format
	statsByFmt := map[outputFormat]*stats{}
	for _, format := range formats {
		statsByFmt[format] = &stats{}
	}

	// Print the table of file x format and gather per-format stats
	failures := []failure{}
	filenameColumnWidth := maxStringLen(files)
	if maxFilenameColumnWidth > 0 {
		filenameColumnWidth = maxFilenameColumnWidth
	}

	red := color.New(color.FgRed)
	green := color.New(color.FgGreen)
	yellow := color.New(color.FgYellow)
	cyan := color.New(color.FgCyan)

	// Note: computed strings are always printed as operands, never as format
	// strings, so that a '%' in a file name cannot be misread as a verb.
	printFormatsHeader := func() {
		fmt.Print(strings.Repeat(" ", filenameColumnWidth))
		fmt.Print(" ┃ ")
		for _, format := range formats {
			cyan.Print(alignCenter(format, formatWidth(format)))
			fmt.Print(" │ ")
		}
		fmt.Println()
	}
	printHorizontalLine := func() {
		fmt.Print(strings.Repeat("━", filenameColumnWidth))
		fmt.Print("━╋━")
		for _, format := range formats {
			fmt.Print(strings.Repeat("━", formatWidth(format)))
			fmt.Print("━┿━")
		}
		fmt.Println()
	}

	fmt.Println()

	printFormatsHeader()
	printHorizontalLine()

	for i, file := range files {
		results := results[i]

		row := &strings.Builder{}
		rowAllPassed := true

		filenameLength := utf8.RuneCountInString(file)
		shortFile := file
		if filenameLength > filenameColumnWidth {
			// Keep the tail of the name. The cut is made on rune boundaries so
			// that multi-byte characters are never split, and columns narrower
			// than the "..." marker simply show the last few runes.
			runes := []rune(file)
			if filenameColumnWidth >= 3 {
				shortFile = "..." + string(runes[filenameLength-filenameColumnWidth+3:])
			} else {
				shortFile = string(runes[filenameLength-filenameColumnWidth:])
			}
		}

		row.WriteString(alignRight(shortFile, filenameColumnWidth))
		row.WriteString(" ┃ ")
		for _, format := range formats {
			columnWidth := formatWidth(format)
			// Blocks until the job for this file and format has completed
			result := <-results[format]
			stats := statsByFmt[format]
			stats.numTests++
			if err := result.err; err != nil {
				failures = append(failures, failure{
					file: file, format: format, err: err,
				})
			}
			switch result.code {
			case pass:
				green.Fprintf(row, "%s", alignCenter("PASS", columnWidth))
				stats.numPass++
			case fail:
				red.Fprintf(row, "%s", alignCenter("FAIL", columnWidth))
				rowAllPassed = false
				stats.numFail++
			case skip:
				yellow.Fprintf(row, "%s", alignCenter("SKIP", columnWidth))
				rowAllPassed = false
				stats.numSkip++
			default:
				row.WriteString(alignCenter(result.code, columnWidth))
				rowAllPassed = false
			}
			row.WriteString(" │ ")
		}

		if verbose || !rowAllPassed {
			fmt.Fprintln(color.Output, row)
		}
	}

	printHorizontalLine()
	printFormatsHeader()
	printHorizontalLine()

	// printStat prints two rows for the statistic 'name': the count per format
	// and the percentage per format. Nothing is printed if every count is zero.
	printStat := func(col *color.Color, name string, num func(*stats) int) {
		row := &strings.Builder{}
		anyNonZero := false
		for _, format := range formats {
			columnWidth := formatWidth(format)
			count := num(statsByFmt[format])
			if count > 0 {
				col.Fprintf(row, "%s", alignLeft(count, columnWidth))
				anyNonZero = true
			} else {
				row.WriteString(alignLeft(count, columnWidth))
			}
			row.WriteString(" │ ")
		}

		if !anyNonZero {
			return
		}
		col.Print(alignRight(name, filenameColumnWidth))
		fmt.Print(" ┃ ")
		fmt.Fprintln(color.Output, row)

		col.Print(strings.Repeat(" ", filenameColumnWidth))
		fmt.Print(" ┃ ")
		for _, format := range formats {
			columnWidth := formatWidth(format)
			stats := statsByFmt[format]
			count := num(stats)
			percent := percentage(count, stats.numTests)
			if count > 0 {
				col.Print(alignRight(percent, columnWidth))
			} else {
				fmt.Print(alignRight(percent, columnWidth))
			}
			fmt.Print(" │ ")
		}
		fmt.Println()
	}
	printStat(green, "PASS", func(s *stats) int { return s.numPass })
	printStat(yellow, "SKIP", func(s *stats) int { return s.numSkip })
	printStat(red, "FAIL", func(s *stats) int { return s.numFail })
	fmt.Println()

	// Print the details of each failure
	for _, f := range failures {
		color.Set(color.FgBlue)
		fmt.Printf("%s ", f.file)
		color.Set(color.FgCyan)
		fmt.Printf("%s ", f.format)
		color.Set(color.FgRed)
		fmt.Println("FAIL")
		color.Unset()
		fmt.Println(indent(f.err.Error(), 4))
	}
	if len(failures) > 0 {
		fmt.Println()
	}

	// Sum the per-format statistics for the final summary line
	allStats := stats{}
	for _, format := range formats {
		stats := statsByFmt[format]
		allStats.numTests += stats.numTests
		allStats.numPass += stats.numPass
		allStats.numSkip += stats.numSkip
		allStats.numFail += stats.numFail
	}

	fmt.Printf("%d tests run", allStats.numTests)
	if allStats.numPass > 0 {
		fmt.Print(", ")
		color.Set(color.FgGreen)
		fmt.Printf("%d tests pass", allStats.numPass)
		color.Unset()
	} else {
		fmt.Printf(", %d tests pass", allStats.numPass)
	}
	if allStats.numSkip > 0 {
		fmt.Print(", ")
		color.Set(color.FgYellow)
		fmt.Printf("%d tests skipped", allStats.numSkip)
		color.Unset()
	} else {
		fmt.Printf(", %d tests skipped", allStats.numSkip)
	}
	if allStats.numFail > 0 {
		fmt.Print(", ")
		color.Set(color.FgRed)
		fmt.Printf("%d tests failed", allStats.numFail)
		color.Unset()
	} else {
		fmt.Printf(", %d tests failed", allStats.numFail)
	}
	fmt.Println()
	fmt.Println()

	if allStats.numFail > 0 {
		os.Exit(1)
	}

	return nil
}

// statusCode is the outcome of a single test. Its string value is the label
// shown for the test in the results table.
type statusCode string

// The possible outcomes of a test.
const (
	fail = statusCode("FAIL")
	pass = statusCode("PASS")
	skip = statusCode("SKIP")
)

// status is the result of running a single job.
type status struct {
	// code is the outcome of the test (pass, fail or skip)
	code statusCode
	// err holds the details of a failure. job.run only sets it together with
	// the fail code.
	err  error
}

// job describes a single test to run: one test file compiled to one output
// format.
type job struct {
	// file is the path of the test file handed to the tint executable
	file   string
	// format is the output format requested from tint
	format outputFormat
	// result is the channel on which run sends the status of the test
	result chan status
}

// run executes the test described by j and sends exactly one status on
// j.result.
//
// The tint executable 'exe' is invoked from the working directory 'wd' on
// j.file with the output format j.format, and its output is compared against
// the expected file for that format (if one exists).
//
//   - fxc, dxcPath and xcrunPath select the validation tool passed to tint for
//     HLSL and MSL output.
//   - generateExpected causes the expected file to be (re)written with the
//     output of a successful compile.
//   - generateSkip causes the expected file of a non-passing test to be
//     rewritten with the special SKIP token followed by the output.
//
// A failure to write an expected file is reported as a test failure instead of
// being silently ignored.
func (j job) run(wd, exe string, fxc bool, dxcPath, xcrunPath string, generateExpected, generateSkip bool) {
	j.result <- func() status {
		// Is there an expected output?
		expected := loadExpectedFile(j.file, j.format)
		// An expected file starting with the special SKIP token marks the test
		// as skipped.
		skipped := strings.HasPrefix(expected, "SKIP")

		expected = strings.ReplaceAll(expected, "\r\n", "\n")

		file, err := filepath.Rel(wd, j.file)
		if err != nil {
			file = j.file
		}

		// Make relative paths use forward slash separators (on Windows) so that paths in tint
		// output match expected output that contain errors
		file = strings.ReplaceAll(file, `\`, `/`)

		args := []string{
			file,
			"--format", string(j.format),
		}

		// Can we validate?
		validate := false
		switch j.format {
		case wgsl:
			validate = true
		case spvasm:
			args = append(args, "--validate") // spirv-val is statically linked, always available
			validate = true
		case hlsl:
			if fxc {
				args = append(args, "--fxc")
				validate = true
			} else if dxcPath != "" {
				args = append(args, "--dxc", dxcPath)
				validate = true
			}
		case msl:
			if xcrunPath != "" {
				args = append(args, "--xcrun", xcrunPath)
				validate = true
			}
		}

		// Invoke the compiler...
		ok, out := invoke(wd, exe, args...)
		out = strings.ReplaceAll(out, "\r\n", "\n")
		// A missing (or empty) expected file matches any output
		matched := expected == "" || expected == out

		if ok && generateExpected && (validate || !skipped) {
			if err := saveExpectedFile(j.file, j.format, out); err != nil {
				return status{
					code: fail,
					err:  fmt.Errorf("failed to write expected file for '%s': %w", j.file, err),
				}
			}
			matched = true
		}

		// recordSkip rewrites the expected file with the SKIP token followed by
		// the output if generateSkip is set, otherwise it does nothing.
		recordSkip := func() error {
			if !generateSkip {
				return nil
			}
			if err := saveExpectedFile(j.file, j.format, "SKIP: FAILED\n\n"+out); err != nil {
				return fmt.Errorf("failed to write SKIP expected file for '%s': %w", j.file, err)
			}
			return nil
		}

		switch {
		case ok && matched:
			// Test passed
			return status{code: pass}

			//       --- Below this point the test has failed ---

		case skipped:
			if err := recordSkip(); err != nil {
				return status{code: fail, err: err}
			}
			return status{code: skip}

		case !ok:
			// Compiler returned non-zero exit code
			err := fmt.Errorf("%s", out)
			if saveErr := recordSkip(); saveErr != nil {
				err = fmt.Errorf("%v\n\n%w", err, saveErr)
			}
			return status{code: fail, err: err}

		default:
			// Compiler returned zero exit code, but the output did not match
			// the expected output
			saveErr := recordSkip()

			dmp := diffmatchpatch.New()
			diff := dmp.DiffPrettyText(dmp.DiffMain(expected, out, true))
			err := fmt.Errorf(`Output was not as expected

--------------------------------------------------------------------------------
-- Expected:                                                                  --
--------------------------------------------------------------------------------
%s

--------------------------------------------------------------------------------
-- Got:                                                                       --
--------------------------------------------------------------------------------
%s

--------------------------------------------------------------------------------
-- Diff:                                                                      --
--------------------------------------------------------------------------------
%s`,
				expected, out, diff)
			if saveErr != nil {
				err = fmt.Errorf("%v\n\n%w", err, saveErr)
			}
			return status{code: fail, err: err}
		}
	}()
}

// loadExpectedFile returns the content of the expected output file that
// belongs to the test file at 'path' for the output format 'format'.
// An empty string is returned if that file is missing or cannot be read.
func loadExpectedFile(path string, format outputFormat) string {
	data, err := ioutil.ReadFile(expectedFilePath(path, format))
	if err == nil {
		return string(data)
	}
	return ""
}

// saveExpectedFile stores 'content' as the expected output of the test file at
// 'path' for the output format 'format'. The error of the underlying file
// write, if any, is returned.
func saveExpectedFile(path string, format outputFormat, content string) error {
	target := expectedFilePath(path, format)
	data := []byte(content)
	return ioutil.WriteFile(target, data, 0666)
}

// expectedFilePath builds the path of the expected output file for the test
// file at 'path' and the output format 'format', which is the test file path
// followed by ".expected." and the format name.
func expectedFilePath(path string, format outputFormat) string {
	parts := []string{path, "expected", string(format)}
	return strings.Join(parts, ".")
}

// indent prefixes every line of 's' with 'n' space characters and returns the
// result. Lines are separated by '\n'.
func indent(s string, n int) string {
	prefix := strings.Repeat(" ", n)
	lines := strings.Split(s, "\n")
	for i := range lines {
		lines[i] = prefix + lines[i]
	}
	return strings.Join(lines, "\n")
}

// alignLeft returns the string of 'val' padded with trailing spaces so that it
// is aligned left in a column of the given width (measured in runes).
// If the string is already as wide as, or wider than, the column it is
// returned unchanged rather than panicking on a negative padding.
func alignLeft(val interface{}, width int) string {
	s := fmt.Sprint(val)
	padding := width - utf8.RuneCountInString(s)
	if padding <= 0 {
		// strings.Repeat panics on a negative count
		return s
	}
	return s + strings.Repeat(" ", padding)
}

// alignCenter returns the string of 'val' padded with spaces on both sides so
// that it is centered in a column of the given width (measured in runes).
// When the padding is odd, the extra space goes on the right.
// If the string is already as wide as, or wider than, the column it is
// returned unchanged rather than panicking on a negative padding.
func alignCenter(val interface{}, width int) string {
	s := fmt.Sprint(val)
	padding := width - utf8.RuneCountInString(s)
	if padding <= 0 {
		// strings.Repeat panics on a negative count
		return s
	}
	return strings.Repeat(" ", padding/2) + s + strings.Repeat(" ", (padding+1)/2)
}

// alignRight returns the string of 'val' padded with leading spaces so that it
// is aligned right in a column of the given width (measured in runes).
// If the string is already as wide as, or wider than, the column it is
// returned unchanged rather than panicking on a negative padding.
func alignRight(val interface{}, width int) string {
	s := fmt.Sprint(val)
	padding := width - utf8.RuneCountInString(s)
	if padding <= 0 {
		// strings.Repeat panics on a negative count
		return s
	}
	return strings.Repeat(" ", padding) + s
}

// maxStringLen reports the rune count of the longest string in 'l', or 0 if
// 'l' is empty.
func maxStringLen(l []string) int {
	longest := 0
	for i := range l {
		n := utf8.RuneCountInString(l[i])
		if n <= longest {
			continue
		}
		longest = n
	}
	return longest
}

// formatWidth returns the width, in runes, of the table column used for the
// outputFormat 'b'. The column is as wide as the format name, but never
// narrower than 6 runes.
func formatWidth(b outputFormat) int {
	const minWidth = 6
	if n := utf8.RuneCountInString(string(b)); n >= minWidth {
		return n
	}
	return minWidth
}

// percentage formats n as a percentage of total with one decimal place,
// followed by a '%' sign. A total of zero yields "-".
func percentage(n, total int) string {
	if total != 0 {
		ratio := float64(n) / float64(total)
		return fmt.Sprintf("%.1f%%", ratio*100.0)
	}
	return "-"
}

// invoke runs the executable 'exe' with the provided arguments, using 'wd' as
// the working directory, and waits for it to finish or for testTimeout to
// elapse.
//
// On success, ok is true and output is the combined standard output and
// standard error of the process. Otherwise ok is false and output is, in order
// of preference: a timeout message, the combined output of the process, or the
// text of the error that was returned when running it.
func invoke(wd, exe string, args ...string) (ok bool, output string) {
	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
	defer cancel()

	cmd := exec.CommandContext(ctx, exe, args...)
	cmd.Dir = wd
	raw, runErr := cmd.CombinedOutput()

	switch {
	case runErr == nil:
		return true, string(raw)
	case ctx.Err() == context.DeadlineExceeded:
		return false, fmt.Sprintf("test timed out after %v", testTimeout)
	case len(raw) > 0:
		return false, string(raw)
	default:
		return false, runErr.Error()
	}
}
