// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

// +build !codes

package testutil

import (
	"bytes"
	"encoding/json"
	"flag"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"reflect"
	"regexp"
	"runtime"
	"sort"
	"strings"

	"github.com/pingcap/check"
	"github.com/pingcap/errors"
	"github.com/pingcap/parser/mysql"
	"github.com/pingcap/tidb/kv"
	"github.com/pingcap/tidb/sessionctx/stmtctx"
	"github.com/pingcap/tidb/types"
	"github.com/pingcap/tidb/util/codec"
	"github.com/pingcap/tidb/util/logutil"
	"go.uber.org/zap"
)

// CompareUnorderedStringSlice reports whether a and b contain the same
// strings with the same multiplicities, regardless of order.
// Two nil slices are considered equal, while a nil slice is never equal to a
// non-nil slice (even an empty one).
func CompareUnorderedStringSlice(a []string, b []string) bool {
	switch {
	case a == nil && b == nil:
		return true
	case a == nil || b == nil:
		return false
	case len(a) != len(b):
		return false
	}

	// Count every string of a, then cancel the counts out with the strings of b.
	remaining := make(map[string]int, len(a))
	for _, s := range a {
		remaining[s]++
	}
	for _, s := range b {
		count, found := remaining[s]
		if !found {
			return false
		}
		if count == 1 {
			delete(remaining, s)
		} else {
			remaining[s] = count - 1
		}
	}
	return len(remaining) == 0
}

// datumEqualsChecker is a checker for DatumEquals.
// It embeds *check.CheckerInfo, which holds the checker name and the
// parameter names supplied where DatumEquals is declared.
type datumEqualsChecker struct {
	*check.CheckerInfo
}

// DatumEquals checker verifies that the obtained value is equal to
// the expected value. Both values must be of type types.Datum; any other
// type makes the check fail.
// For example:
//     c.Assert(value, DatumEquals, NewDatum(42))
var DatumEquals check.Checker = &datumEqualsChecker{
	&check.CheckerInfo{Name: "DatumEquals", Params: []string{"obtained", "expected"}},
}

// Check compares params[0] (obtained) with params[1] (expected) as
// types.Datum values and reports whether they compare equal.
// Any panic raised along the way (wrong parameter type, comparison error) is
// recovered, logged, and turned into a failed check whose error text is the
// panic value.
func (checker *datumEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		result = false
		error = fmt.Sprint(recovered)
		logutil.BgLogger().Error("panic in datumEqualsChecker.Check",
			zap.Reflect("r", recovered),
			zap.Stack("stack trace"))
	}()

	obtained, isDatum := params[0].(types.Datum)
	if !isDatum {
		panic("the first param should be datum")
	}
	expected, isDatum := params[1].(types.Datum)
	if !isDatum {
		panic("the second param should be datum")
	}

	cmp, err := obtained.CompareDatum(new(stmtctx.StatementContext), &expected)
	if err != nil {
		panic(err)
	}
	return cmp == 0, ""
}

// MustNewCommonHandle creates a common handle from the given values.
// The values are converted to datums and encoded into a key; an error from
// either the encoding or the handle construction is asserted to be nil on c.
func MustNewCommonHandle(c *check.C, values ...interface{}) kv.Handle {
	sc := new(stmtctx.StatementContext)
	datums := types.MakeDatums(values...)
	key, encodeErr := codec.EncodeKey(sc, nil, datums...)
	c.Assert(encodeErr, check.IsNil)

	handle, handleErr := kv.NewCommonHandle(key)
	c.Assert(handleErr, check.IsNil)
	return handle
}

// CommonHandleSuite is used to adapt kv.CommonHandle to existing kv.IntHandle tests.
// Embed it in a test suite and let each test re-run itself with common
// handles enabled.
//  Usage:
//   type MyTestSuite struct {
//       CommonHandleSuite
//   }
//   func (s *MyTestSuite) TestSomething(c *C) {
//       // ...
//       s.RerunWithCommonHandleEnabled(c, s.TestSomething)
//   }
type CommonHandleSuite struct {
	// IsCommonHandle selects the kind of handle produced through NewHandle:
	// a common handle when true, an int handle otherwise.
	IsCommonHandle bool
}

// RerunWithCommonHandleEnabled runs the test function f with IsCommonHandle
// set to true and restores it to false afterwards.
// It is a no-op when IsCommonHandle is already true, so a test that passes
// itself as f is re-run exactly once instead of recursing forever.
func (chs *CommonHandleSuite) RerunWithCommonHandleEnabled(c *check.C, f func(*check.C)) {
	if chs.IsCommonHandle {
		return
	}
	chs.IsCommonHandle = true
	// Restore the flag in a defer so that it is reset even when f does not
	// return normally (for example it panics, or the goroutine is terminated
	// through runtime.Goexit). Otherwise the flag would leak into later tests
	// that share this suite value.
	defer func() { chs.IsCommonHandle = false }()
	f(c)
}

// NewHandle returns a handle builder that is configured from the current
// value of CommonHandleSuite.IsCommonHandle.
func (chs *CommonHandleSuite) NewHandle() *commonHandleSuiteNewHandleBuilder {
	builder := new(commonHandleSuiteNewHandleBuilder)
	builder.isCommon = chs.IsCommonHandle
	return builder
}

// commonHandleSuiteNewHandleBuilder collects the values for both kinds of
// handle and builds the one selected by isCommon:
//   - isCommon:   true to build a common handle, false to build an int handle.
//   - intVal:     the value used for the int handle.
//   - commonVals: the values encoded into the common handle.
type commonHandleSuiteNewHandleBuilder struct {
	isCommon   bool
	intVal     int64
	commonVals []interface{}
}

// Int records v as the value to use when an int handle is built and returns
// the builder itself so that calls can be chained.
func (c *commonHandleSuiteNewHandleBuilder) Int(v int64) *commonHandleSuiteNewHandleBuilder {
	c.intVal = v
	return c
}

// Common records vs as the values to use when a common handle is built and
// immediately returns the result of Build.
func (c *commonHandleSuiteNewHandleBuilder) Common(vs ...interface{}) kv.Handle {
	c.commonVals = vs
	return c.Build()
}

// Build returns the handle described by the builder: a kv.IntHandle made from
// the recorded int value when isCommon is false, otherwise a common handle
// made from the encoded common values.
// It panics if the common values cannot be encoded or the common handle
// cannot be constructed.
func (c *commonHandleSuiteNewHandleBuilder) Build() kv.Handle {
	if !c.isCommon {
		return kv.IntHandle(c.intVal)
	}

	datums := types.MakeDatums(c.commonVals...)
	key, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, datums...)
	if err != nil {
		panic(err)
	}
	handle, err := kv.NewCommonHandle(key)
	if err != nil {
		panic(err)
	}
	return handle
}

// handleEqualsChecker is a checker for HandleEquals.
// It embeds *check.CheckerInfo, which holds the checker name and the
// parameter names supplied where HandleEquals is declared.
type handleEqualsChecker struct {
	*check.CheckerInfo
}

// HandleEquals checker verifies that the obtained handle is equal to
// the expected handle. Two nil values are considered equal; otherwise both
// values must be kv.Handle of the same kind (both int or both non-int).
// For example:
//     c.Assert(value, HandleEquals, kv.IntHandle(42))
var HandleEquals = &handleEqualsChecker{
	&check.CheckerInfo{Name: "HandleEquals", Params: []string{"obtained", "expected"}},
}

// Check compares params[0] (obtained) with params[1] (expected).
// It succeeds when both are nil, or when both are kv.Handle values of the
// same kind whose String representations are identical.
// It fails with an explanatory message when either value is not a kv.Handle
// or when one handle is an int handle and the other is not.
func (checker *handleEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) {
	if params[0] == nil && params[1] == nil {
		return true, ""
	}
	obtained, obtainedOK := params[0].(kv.Handle)
	expected, expectedOK := params[1].(kv.Handle)
	if !obtainedOK || !expectedOK {
		return false, "Argument to " + checker.Name + " must be kv.Handle"
	}
	if obtained.IsInt() != expected.IsInt() {
		// The space after "to" was missing, which glued the checker name to
		// the preceding word ("toHandleEquals").
		return false, "Two handle types arguments to " + checker.Name + " must be same"
	}

	return obtained.String() == expected.String(), ""
}

// RowsWithSep converts each argument into one row by splitting it on sep.
// Every resulting field is stored as a string inside an interface{} value,
// and the rows are returned in argument order.
func RowsWithSep(sep string, args ...string) [][]interface{} {
	rows := make([][]interface{}, 0, len(args))
	for _, line := range args {
		fields := strings.Split(line, sep)
		row := make([]interface{}, 0, len(fields))
		for _, field := range fields {
			row = append(row, field)
		}
		rows = append(rows, row)
	}
	return rows
}

// record is set by the -record command-line flag. When it is true, the test
// data helpers in this file generate the expected results instead of loading
// them from the output file.
var record bool

// init registers the -record flag on the default command-line flag set.
func init() {
	flag.BoolVar(&record, "record", false, "to generate test result")
}

// testCases is one named entry of a test suite JSON file.
// Name identifies the entry, Cases keeps the raw JSON of its cases until a
// test asks for them, and decodedOut holds the value a test registered as its
// output so that it can be encoded when the output file is generated.
type testCases struct {
	Name       string
	Cases      *json.RawMessage // For delayed parse.
	decodedOut interface{}      // For generate output.
}

// TestData stores all the data of a test suite.
// input and output hold the entries of the "<prefix>_in.json" and
// "<prefix>_out.json" files, where filePathPrefix is the prefix. funcMap maps
// an entry name to its index in input and output.
type TestData struct {
	input          []testCases
	output         []testCases
	filePathPrefix string
	funcMap        map[string]int
}

// LoadTestSuiteData loads the data of the suite suiteName located in dir.
// The input cases are read from "<dir>/<suiteName>_in.json". Unless the
// record flag is set, the expected output is read from
// "<dir>/<suiteName>_out.json" and must contain the same number of entries,
// with the same names in the same order, as the input. When the record flag
// is set, the output entries are created empty with the names of the input.
func LoadTestSuiteData(dir, suiteName string) (res TestData, err error) {
	res.filePathPrefix = filepath.Join(dir, suiteName)

	inputPath := fmt.Sprintf("%s_in.json", res.filePathPrefix)
	if res.input, err = loadTestSuiteCases(inputPath); err != nil {
		return res, err
	}
	caseCount := len(res.input)

	if record {
		// Only the names are known for now; the cases are produced by the tests.
		res.output = make([]testCases, caseCount)
		for i := range res.input {
			res.output[i].Name = res.input[i].Name
		}
	} else {
		outputPath := fmt.Sprintf("%s_out.json", res.filePathPrefix)
		if res.output, err = loadTestSuiteCases(outputPath); err != nil {
			return res, err
		}
		if caseCount != len(res.output) {
			return res, errors.New(fmt.Sprintf("Number of test input cases %d does not match test output cases %d", caseCount, len(res.output)))
		}
	}

	// Index the entries by name and make sure input and output are aligned.
	res.funcMap = make(map[string]int, caseCount)
	for i := range res.input {
		name := res.input[i].Name
		res.funcMap[name] = i
		if outName := res.output[i].Name; name != outName {
			return res, errors.New(fmt.Sprintf("Input name of the %d-case %s does not match output %s", i, name, outName))
		}
	}
	return res, nil
}

// loadTestSuiteCases reads the file at filePath, strips every span that
// starts with "//" and runs to the end of the line, and decodes the remainder
// as a JSON array of testCases.
// An error from closing the file is reported only when nothing else failed.
func loadTestSuiteCases(filePath string) (res []testCases, err error) {
	f, err := os.Open(filePath)
	if err != nil {
		return res, err
	}
	defer func() {
		closeErr := f.Close()
		if err == nil && closeErr != nil {
			err = closeErr
		}
	}()

	raw, err := ioutil.ReadAll(f)
	if err != nil {
		return res, err
	}

	// Remove comments, since they are not allowed in json.
	commentPattern := regexp.MustCompile("(?s)//.*?\n")
	stripped := commentPattern.ReplaceAll(raw, nil)
	err = json.Unmarshal(stripped, &res)
	return res, err
}

// GetTestCasesByName gets the test cases registered under caseName.
// The input cases are decoded into in. Unless the record flag is set, the
// expected output is decoded into out; when the record flag is set and out
// points to a slice, that slice is instead replaced by a zeroed slice with
// one element per input case, ready to be filled by the test.
// In both modes out is remembered as the decoded output of the entry so that
// it can be written out when the output file is generated.
// Every failure is reported through assertions on c.
func (t *TestData) GetTestCasesByName(caseName string, c *check.C, in interface{}, out interface{}) {
	casesIdx, ok := t.funcMap[caseName]
	c.Assert(ok, check.IsTrue, check.Commentf("Must get test %s", caseName))

	// Cases stays nil when the JSON entry carries no cases; assert it
	// explicitly rather than dereferencing a nil pointer.
	c.Assert(t.input[casesIdx].Cases, check.NotNil, check.Commentf("Test %s has no input cases", caseName))
	err := json.Unmarshal(*t.input[casesIdx].Cases, in)
	c.Assert(err, check.IsNil)

	if record {
		// Init for generate output file.
		inputLen := reflect.ValueOf(in).Elem().Len()
		v := reflect.ValueOf(out).Elem()
		if v.Kind() == reflect.Slice {
			v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen))
		}
	} else {
		c.Assert(t.output[casesIdx].Cases, check.NotNil, check.Commentf("Test %s has no output cases", caseName))
		err = json.Unmarshal(*t.output[casesIdx].Cases, out)
		c.Assert(err, check.IsNil)
	}
	t.output[casesIdx].decodedOut = out
}

// GetTestCases gets the test cases for the calling test function.
// The case name is the part of the direct caller's function name that follows
// the last "."; the cases are then loaded exactly as GetTestCasesByName does.
func (t *TestData) GetTestCases(c *check.C, in interface{}, out interface{}) {
	// Extract caller's name.
	pc, _, _, ok := runtime.Caller(1)
	c.Assert(ok, check.IsTrue)
	details := runtime.FuncForPC(pc)
	c.Assert(details, check.NotNil)
	fullName := details.Name()
	funcName := fullName[strings.LastIndex(fullName, ".")+1:]

	t.GetTestCasesByName(funcName, c, in, out)
}

// OnRecord calls updateFunc only when the record flag is set; otherwise it
// does nothing.
func (t *TestData) OnRecord(updateFunc func()) {
	if !record {
		return
	}
	updateFunc()
}

// ConvertRowsToStrings converts [][]interface{} to []string.
// Each row is rendered with the default fmt formatting of a slice and the
// enclosing square brackets are removed. The result is nil when rows is empty.
func (t *TestData) ConvertRowsToStrings(rows [][]interface{}) (rs []string) {
	for i := range rows {
		formatted := fmt.Sprint(rows[i])
		// fmt renders a slice as "[e1 e2 ...]"; drop the leading `[` and the
		// trailing `]`.
		rs = append(rs, formatted[1:len(formatted)-1])
	}
	return rs
}

// ConvertSQLWarnToStrings converts []SQLWarn to []string by taking the error
// message of every warning, in order. The result is nil when warns is empty.
func (t *TestData) ConvertSQLWarnToStrings(warns []stmtctx.SQLWarn) (rs []string) {
	for i := range warns {
		message := warns[i].Err.Error()
		rs = append(rs, message)
	}
	return rs
}

// GenerateOutputIfNeeded writes the output file of the suite when the record
// flag is set and does nothing otherwise.
// The decoded output registered for each entry is encoded as indented JSON
// (without HTML escaping) and stored as the cases of that entry; the whole
// list of entries is then encoded and written to "<prefix>_out.json",
// replacing any existing file.
// It returns the first error met while encoding, creating, writing or
// closing the file.
func (t *TestData) GenerateOutputIfNeeded() (err error) {
	if !record {
		return nil
	}

	buf := new(bytes.Buffer)
	enc := json.NewEncoder(buf)
	enc.SetEscapeHTML(false)
	enc.SetIndent("", "  ")
	for i, test := range t.output {
		if err = enc.Encode(test.decodedOut); err != nil {
			return err
		}
		// Copy the encoded bytes out before the buffer is reused for the
		// next entry.
		res := make([]byte, buf.Len())
		copy(res, buf.Bytes())
		buf.Reset()
		rm := json.RawMessage(res)
		t.output[i].Cases = &rm
	}
	if err = enc.Encode(t.output); err != nil {
		return err
	}

	file, err := os.Create(fmt.Sprintf("%s_out.json", t.filePathPrefix))
	if err != nil {
		return err
	}
	// err is a named result so that a failure to close the file, which may
	// be the only sign that the data was not fully written, reaches the
	// caller instead of being assigned to a variable nobody reads.
	defer func() {
		if closeErr := file.Close(); err == nil && closeErr != nil {
			err = closeErr
		}
	}()
	_, err = file.Write(buf.Bytes())
	return err
}

// MaskSortHandles sorts the handles by lowest (fieldTypeBits - 1 - shardBitsCount) bits.
// Each handle is shifted left and then right by
// 64 - fieldTypeBits + shardBitsCount + 1 bits, which discards its high bits,
// and the masked values are returned in ascending order in a new slice; the
// handles slice itself is left untouched.
func MaskSortHandles(handles []int64, shardBitsCount int, fieldType byte) []int64 {
	const signBitCount = 1
	typeBitsLength := mysql.DefaultLengthOfMysqlTypes[fieldType] * 8
	shiftBitsCount := 64 - typeBitsLength + shardBitsCount + signBitCount

	masked := make([]int64, 0, len(handles))
	for _, handle := range handles {
		masked = append(masked, handle<<shiftBitsCount>>shiftBitsCount)
	}
	sort.Slice(masked, func(a, b int) bool { return masked[a] < masked[b] })
	return masked
}
