diff --git a/auth/caching_sha2.go b/auth/caching_sha2.go index 9fb684191..803e3a455 100644 --- a/auth/caching_sha2.go +++ b/auth/caching_sha2.go @@ -44,6 +44,7 @@ const ( MIXCHARS = 32 SALT_LENGTH = 20 ITERATION_MULTIPLIER = 1000 + MAX_PASSWORD_LENGTH = 256 ) func b64From24bit(b []byte, n int) []byte { @@ -218,3 +219,29 @@ func NewSha2Password(pwd string) string { return sha256crypt(pwd, salt, 5*ITERATION_MULTIPLIER) } + +// GenerateScramble creates a scrambe that can be used for caching_sha2 fast authentication +// See also: https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html +func GenerateScramble(password []byte, nonce []byte) ([]byte, error) { + if len(password) > MAX_PASSWORD_LENGTH { + return nil, errors.New("invalid password length for caching_sha2_password scramble generation") + } + + if len(nonce) != SALT_LENGTH { + return nil, errors.New("invalid nonce length for caching_sha2_password scramble generation") + } + + digestStage1 := sha256.Sum256(password) + digestStage2 := sha256.Sum256(digestStage1[:]) + tempStage3 := sha256.New() + tempStage3.Write(digestStage2[:]) + tempStage3.Write(nonce) + digestStage3 := tempStage3.Sum(nil) + + newscramble := digestStage1 + for i := range digestStage1 { + newscramble[i] ^= digestStage3[i] + } + + return newscramble[:], nil +} diff --git a/auth/caching_sha2_test.go b/auth/caching_sha2_test.go index 32a7cc57c..5ccabb4ff 100644 --- a/auth/caching_sha2_test.go +++ b/auth/caching_sha2_test.go @@ -15,6 +15,7 @@ package auth import ( "encoding/hex" + "strings" . "github.com/pingcap/check" ) @@ -71,3 +72,21 @@ func (s *testAuthSuite) TestNewSha2Password(c *C) { c.Assert(pwhash[r], Not(Equals), 36) // '$' } } + +func (s *testAuthSuite) TestGenerateScramble(c *C) { + pwd := []byte("abc") + nonce, _ := hex.DecodeString("6d642b676464321c561d476e094c316d05180b16") + storedScramble, _ := hex.DecodeString("3455aae998ae2959d7170229375c9626735bf545d2a828a7d45f94f9b2cf19f4") + scramble, err := GenerateScramble(pwd, nonce) + c.Assert(err, IsNil) + c.Assert(scramble, BytesEquals, storedScramble) + + _, err = GenerateScramble([]byte(strings.Repeat("x", MAX_PASSWORD_LENGTH+1)), nonce) + c.Assert(err, NotNil) + + _, err = GenerateScramble(pwd, append(nonce, byte('A'))) + c.Assert(err, NotNil) + + _, err = GenerateScramble(pwd, nonce[:SALT_LENGTH-1]) + c.Assert(err, NotNil) +}