// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package session

import (
	"encoding/json"
	"math"
	"strconv"
	"strings"
	"time"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/util/hack"
)

// reserveBuffer lengthens buf by appendSize bytes and returns the lengthened
// slice, keeping the existing contents of buf as its prefix.
//
// If buf already has enough capacity, the result shares buf's backing array
// and the newly exposed tail holds whatever bytes were already there.
// Otherwise a fresh array of 2*len(buf)+appendSize bytes is allocated and the
// old contents are copied in. The headroom makes repeated calls grow the
// buffer geometrically.
func reserveBuffer(buf []byte, appendSize int) []byte {
	oldLen := len(buf)
	want := oldLen + appendSize
	if want <= cap(buf) {
		return buf[:want]
	}
	grown := make([]byte, 2*oldLen+appendSize)
	copy(grown, buf)
	return grown[:want]
}

// escapeBytesBackslash appends v to buf with backslash escaping applied and
// returns the extended slice. The result may share buf's backing array.
//
// The bytes NUL, '\n', '\r' and 0x1a are written as \0, \n, \r and \Z.
// The quote characters ' and " and the backslash itself are written with a
// preceding backslash. All other bytes are copied verbatim.
func escapeBytesBackslash(buf []byte, v []byte) []byte {
	// Reserve the worst case (every byte escaped into two) up front so the
	// loop below can write by index without any further growth checks.
	n := len(buf)
	buf = reserveBuffer(buf, 2*len(v))

	for _, b := range v {
		var esc byte
		switch b {
		case '\x00':
			esc = '0'
		case '\n':
			esc = 'n'
		case '\r':
			esc = 'r'
		case '\x1a':
			esc = 'Z'
		case '\'', '"', '\\':
			esc = b
		default:
			buf[n] = b
			n++
			continue
		}
		buf[n], buf[n+1] = '\\', esc
		n += 2
	}

	// Drop the unused part of the worst-case reservation.
	return buf[:n]
}

// escapeStringBackslash will escape string into the buffer, with backslash.
func escapeStringBackslash(buf []byte, v string) []byte {
	return escapeBytesBackslash(buf, hack.Slice(v))
}

// EscapeSQL will escape input arguments into the sql string, doing necessary processing.
// It works like printf() in c; the following format specifiers are supported:
//  1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..)
//  2. %%: output %
//  3. %n: for identifiers, for example ("use %n", db)
//
// Any other '%' (including one at the very end of sql) is copied to the
// output unchanged, and the character after it is copied as plain text.
//
// %n requires a string argument. It is wrapped in backticks, and every
// embedded backtick is doubled.
//
// %? converts its argument as follows:
//   - nil: NULL
//   - signed and unsigned integers: decimal literal
//   - float32, float64: shortest 'g'-format representation. NaN and ±Inf are
//     rejected with an error because they have no SQL numeric literal.
//   - bool: 1 or 0
//   - time.Time: quoted 'YYYY-MM-DD hh:mm:ss[.ffffff]', formatted in the
//     value's own location (no timezone conversion). The zero time becomes
//     '0000-00-00'.
//   - json.RawMessage: quoted and backslash-escaped
//   - []byte: _binary'...' with backslash escaping; a nil []byte becomes NULL
//   - string: quoted and backslash-escaped
//   - []string: parenthesized, comma-separated list of quoted, escaped
//     strings (an empty slice yields "()")
//   - any other type: an error
//
// Arguments are consumed left to right. A missing argument is an error, but
// surplus arguments are silently ignored.
//
// Escaping does not prevent you from doing EscapeSQL("select '%?", ";SQL injection!;") => "select '';SQL injection!;'".
// It is still your responsibility to write safe SQL.
func EscapeSQL(sql string, args ...interface{}) (string, error) {
	buf := make([]byte, 0, len(sql))
	argPos := 0
	for i := 0; i < len(sql); i++ {
		q := strings.IndexByte(sql[i:], '%')
		if q == -1 {
			buf = append(buf, sql[i:]...)
			break
		}
		buf = append(buf, sql[i:i+q]...)
		i += q

		ch := byte(0)
		if i+1 < len(sql) {
			ch = sql[i+1] // get the specifier
		}
		switch ch {
		case 'n':
			if argPos >= len(args) {
				return "", errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args))
			}
			arg := args[argPos]
			argPos++

			v, ok := arg.(string)
			if !ok {
				return "", errors.Errorf("expect a string identifier, got %v", arg)
			}
			buf = append(buf, '`')
			buf = append(buf, strings.Replace(v, "`", "``", -1)...)
			buf = append(buf, '`')
			i++ // skip specifier
		case '?':
			if argPos >= len(args) {
				return "", errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args))
			}
			arg := args[argPos]
			argPos++

			if arg == nil {
				buf = append(buf, "NULL"...)
			} else {
				switch v := arg.(type) {
				case int:
					buf = strconv.AppendInt(buf, int64(v), 10)
				case int8:
					buf = strconv.AppendInt(buf, int64(v), 10)
				case int16:
					buf = strconv.AppendInt(buf, int64(v), 10)
				case int32:
					buf = strconv.AppendInt(buf, int64(v), 10)
				case int64:
					buf = strconv.AppendInt(buf, v, 10)
				case uint:
					buf = strconv.AppendUint(buf, uint64(v), 10)
				case uint8:
					buf = strconv.AppendUint(buf, uint64(v), 10)
				case uint16:
					buf = strconv.AppendUint(buf, uint64(v), 10)
				case uint32:
					buf = strconv.AppendUint(buf, uint64(v), 10)
				case uint64:
					buf = strconv.AppendUint(buf, v, 10)
				case float32:
					// AppendFloat would render NaN/±Inf as bare NaN, +Inf or
					// -Inf tokens. Those are not SQL numeric literals, so the
					// statement would fail to parse or name a column instead.
					if f := float64(v); math.IsNaN(f) || math.IsInf(f, 0) {
						return "", errors.Errorf("unsupported %d-th argument: %v is not a finite number", argPos, arg)
					}
					buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32)
				case float64:
					// Same reasoning as float32 above.
					if math.IsNaN(v) || math.IsInf(v, 0) {
						return "", errors.Errorf("unsupported %d-th argument: %v is not a finite number", argPos, arg)
					}
					buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
				case bool:
					if v {
						buf = append(buf, '1')
					} else {
						buf = append(buf, '0')
					}
				case time.Time:
					if v.IsZero() {
						buf = append(buf, "'0000-00-00'"...)
					} else {
						buf = append(buf, '\'')
						buf = v.AppendFormat(buf, "2006-01-02 15:04:05.999999")
						buf = append(buf, '\'')
					}
				case json.RawMessage:
					buf = append(buf, '\'')
					buf = escapeBytesBackslash(buf, v)
					buf = append(buf, '\'')
				case []byte:
					if v == nil {
						buf = append(buf, "NULL"...)
					} else {
						buf = append(buf, "_binary'"...)
						buf = escapeBytesBackslash(buf, v)
						buf = append(buf, '\'')
					}
				case string:
					buf = append(buf, '\'')
					buf = escapeStringBackslash(buf, v)
					buf = append(buf, '\'')
				case []string:
					buf = append(buf, '(')
					// j, not i: avoid shadowing the outer scan index.
					for j, k := range v {
						if j > 0 {
							buf = append(buf, ',')
						}
						buf = append(buf, '\'')
						buf = escapeStringBackslash(buf, k)
						buf = append(buf, '\'')
					}
					buf = append(buf, ')')
				default:
					// argPos was already advanced, so it is the 1-based index here.
					return "", errors.Errorf("unsupported %d-th argument: %v", argPos, arg)
				}
			}
			i++ // skip specifier
		case '%':
			buf = append(buf, '%')
			i++ // skip specifier
		default:
			// Unknown or missing specifier: emit '%' literally. The loop's i++
			// then continues scanning from the following character as plain text.
			buf = append(buf, '%')
		}
	}
	return string(buf), nil
}
