Skip to content

Placeholder interpolation #309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Feb 14, 2015
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
e35fa00
Implement placeholder substitution.
methane Dec 31, 2014
c8c9bb1
Query() uses client-side placeholder substitution.
methane Dec 31, 2014
cac6129
Don't send text query larger than maxPacketAllowed
methane Jan 20, 2015
b7c2c47
Add substitutePlaceholder option to DSN
methane Jan 20, 2015
f3b82fd
Merge remote-tracking branch 'upstream/pr/297'
arvenil Jan 31, 2015
058ce87
Move escape funcs to utils.go, export them, add references to mysql s…
arvenil Feb 1, 2015
42956fa
Add tests for escaping functions
arvenil Feb 1, 2015
e6bf23a
Add basic SQL injection tests, including NO_BACKSLASH_ESCAPES sql_mode
arvenil Feb 1, 2015
b473259
Test if inserted data is correctly retrieved after being escaped
arvenil Feb 7, 2015
42a1efd
Don't stop test on MySQLWarnings
methane Feb 8, 2015
3c8fa90
substitutePlaceholder -> interpolateParams
methane Feb 8, 2015
6c8484b
Add interpolateParams document to README
methane Feb 8, 2015
04866ee
Fix nits pointed in pull request.
methane Feb 8, 2015
dd7b87c
Add benchmark for interpolateParams()
methane Feb 8, 2015
9faabe5
Don't write microseconds when Time.Nanosecond() == 0
methane Feb 8, 2015
468b9e5
Fix benchmark
methane Feb 8, 2015
0297315
Reduce allocs in interpolateParams.
methane Feb 8, 2015
0b75396
Inline datetime formatting
methane Feb 8, 2015
9f84dfb
Remove one more allocation
methane Feb 8, 2015
8826242
More acculate estimation of upper bound
methane Feb 8, 2015
916a1f2
escapeString -> escapeBackslash
methane Feb 9, 2015
88aeb98
append string... to []byte without cast.
methane Feb 10, 2015
43536c7
Specialize escape functions for string
methane Feb 10, 2015
0c7ae46
test for escapeString*
methane Feb 10, 2015
c285e39
Use digits10 and digits01 to format datetime.
methane Feb 10, 2015
fcea447
Round under microsecond
methane Feb 10, 2015
bfbe6c5
travis: Drop Go 1.1 and add Go 1.4
methane Feb 11, 2015
d65f96a
Fix typo
methane Feb 12, 2015
e11c825
Inlining mysqlConn.escapeBytes and mysqlConn.escapeString
methane Feb 12, 2015
b4f0315
Bit detailed info about vulnerability when using multibyte encoding.
methane Feb 12, 2015
1fd0514
Add link to StackOverflow describe vulnerability using multibyte enco…
methane Feb 12, 2015
20b75cd
Fix comment
methane Feb 12, 2015
e517683
Allow interpolateParams only with ascii, latin1 and utf8 collations
methane Feb 12, 2015
0f22bc2
extract function to reserve buffer
methane Feb 12, 2015
52a5860
Fix missing db.Close()
methane Feb 12, 2015
2a634df
Fix sentence in interpolateParams document.
methane Feb 12, 2015
90cb6c3
Use blacklist to avoid vulnerability with interpolation
methane Feb 12, 2015
9437b61
Adding myself to AUTHORS (however, 99% work done by @methane ;))
arvenil Feb 13, 2015
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 132 additions & 49 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"database/sql/driver"
"errors"
"net"
"strconv"
"strings"
"time"
)
Expand All @@ -26,26 +27,28 @@ type mysqlConn struct {
maxPacketAllowed int
maxWriteSize int
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
strict bool
}

type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
tls *tls.Config
timeout time.Duration
collation uint8
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
columnsWithAlias bool
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
tls *tls.Config
timeout time.Duration
collation uint8
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
columnsWithAlias bool
substitutePlaceholder bool
}

// Handles parameters set in DSN after the connection is established
Expand Down Expand Up @@ -162,28 +165,101 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
return stmt, err
}

// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156
func (mc *mysqlConn) escapeBytes(v []byte) string {
var escape func([]byte) []byte
if mc.status&statusNoBackslashEscapes == 0 {
escape = EscapeString
} else {
escape = EscapeQuotes
}
return "'" + string(escape(v)) + "'"
}

func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about substitutePlaceholder instead of buildQuery?

chunks := strings.Split(query, "?")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be taken with a grain of salt - this is only called once per query. Still...

I'd like to reduce the number of allocations - that's also why I sent you the playground link.
Start with a buffer: make([]byte, 0, 5*len(query)/4) - byte slice with 0 length and 1.25 times capacity of original query length (1.25 is just a best guess for a sensible starting point).
Use the Append... based functions from strconv.
Try looping with strings.IndexByte(..., '?') instead of splitting on '?'.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll rewrite it in two-pass way. First pass is for calculating required length. Second pass is for building []bytes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've forgot about formatting float.
I'll take Append based approach.

if len(chunks) != len(args)+1 {
return "", driver.ErrSkip
}

parts := make([]string, len(chunks)+len(args))
parts[0] = chunks[0]

for i, arg := range args {
pos := i*2 + 1
parts[pos+1] = chunks[i+1]
if arg == nil {
parts[pos] = "NULL"
continue
}
switch v := arg.(type) {
case int64:
parts[pos] = strconv.FormatInt(v, 10)
case float64:
parts[pos] = strconv.FormatFloat(v, 'f', -1, 64)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'g' instead of 'f'? That may be a little smaller in some cases.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

case bool:
if v {
parts[pos] = "1"
} else {
parts[pos] = "0"
}
case time.Time:
if v.IsZero() {
parts[pos] = "'0000-00-00'"
} else {
fmt := "'2006-01-02 15:04:05.999999'"
parts[pos] = v.In(mc.cfg.loc).Format(fmt)
}
case []byte:
if v == nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed, handled in L191

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Required since nil and []byte(nil) is different.

parts[pos] = "NULL"
} else {
parts[pos] = mc.escapeBytes(v)
}
case string:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When []byte and string are essentially the same code, can they be in one case []byte, string:? Can escapeBytes be inlined?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I don't know that I can use case []byte, string: in type switch switch v := arg.(type).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't...
./connection.go:214: cannot convert v (type driver.Value) to type []byte: need type assertion

parts[pos] = mc.escapeBytes([]byte(v))
default:
return "", driver.ErrSkip
}
}
pktSize := len(query) + 4 // 4 bytes for header.
for _, p := range parts {
pktSize += len(p)
}
if pktSize > mc.maxPacketAllowed {
return "", driver.ErrSkip
}
return strings.Join(parts, ""), nil
}

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
mc.affectedRows = 0
mc.insertId = 0

err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
if len(args) != 0 {
if !mc.cfg.substitutePlaceholder {
return nil, driver.ErrSkip
}
return nil, err
// try client-side prepare to reduce roundtrip
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is a bit confusing (same in L368). Maybe use something like this instead:
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement

prepared, err := mc.buildQuery(query, args)
if err != nil {
return nil, err
}
query = prepared
args = nil
}
mc.affectedRows = 0
mc.insertId = 0

// with args, must use prepared stmt
return nil, driver.ErrSkip

err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, err
}

// Internal function to execute commands
Expand Down Expand Up @@ -212,31 +288,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if len(args) != 0 {
if !mc.cfg.substitutePlaceholder {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce roundtrip
prepared, err := mc.buildQuery(query, args)
if err != nil {
return nil, err
}
query = prepared
args = nil
}
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc

if resLen == 0 {
// no columns, no more data
return emptyRows{}, nil
}
// Columns
rows.columns, err = mc.readColumns(resLen)
return rows, err
rows := new(textRows)
rows.mc = mc

if resLen == 0 {
// no columns, no more data
return emptyRows{}, nil
}
// Columns
rows.columns, err = mc.readColumns(resLen)
return rows, err
}
return nil, err
}

// with args, must use prepared stmt
return nil, driver.ErrSkip
return nil, err
}

// Gets the value of the given MySQL System Variable
Expand Down
22 changes: 22 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,25 @@ const (
flagUnknown3
flagUnknown4
)

// http://dev.mysql.com/doc/internals/en/status-flags.html

type statusFlag uint16

const (
statusInTrans statusFlag = 1 << iota
statusInAutocommit
statusUnknown1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

statusReserved // not in documentation please

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

statusMoreResultsExists
statusNoGoodIndexUsed
statusNoIndexUsed
statusCursorExists
statusLastRowSent
statusDbDropped
statusNoBackslashEscapes
statusMetadataChanged
statusQueryWasSlow
statusPsOutParams
statusInTransReadonly
statusSessionStateChanged
)
68 changes: 68 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {

db.Exec("DROP TABLE IF EXISTS test")

dbp, err := sql.Open("mysql", dsn+"&substitutePlaceholder=true")
if err != nil {
t.Fatalf("Error connecting: %s", err.Error())
}
defer dbp.Close()

dbt := &DBTest{t, db}
dbtp := &DBTest{t, dbp}
for _, test := range tests {
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
test(dbtp)
dbtp.db.Exec("DROP TABLE IF EXISTS test")
}
}

Expand Down Expand Up @@ -1538,3 +1547,62 @@ func TestCustomDial(t *testing.T) {
t.Fatalf("Connection failed: %s", err.Error())
}
}

func TestSqlInjection(t *testing.T) {
createTest := func(arg string) func(dbt *DBTest) {
return func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (v INTEGER)")
dbt.mustExec("INSERT INTO test VALUES (?)", 1)

var v int
// NULL can't be equal to anything, the idea here is to inject query so it returns row
// This test verifies that EscapeQuotes and EscapeStrings are working properly
err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v)
if err == sql.ErrNoRows {
return // success, sql injection failed
} else if err == nil {
dbt.Errorf("Sql injection successful with arg: %s", arg)
} else {
dbt.Errorf("Error running query with arg: %s; err: %s", arg, err.Error())
}
}
}

dsns := []string{
dsn,
dsn + "&sql_mode=NO_BACKSLASH_ESCAPES",
}
for _, testdsn := range dsns {
runTests(t, testdsn, createTest("1 OR 1=1"))
runTests(t, testdsn, createTest("' OR '1'='1"))
}
}

// Test if inserted data is correctly retrieved after being escaped
func TestInsertRetrieveEscapedData(t *testing.T) {
testData := func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (v VARCHAR(255))")

// All sequences that are escaped by EscapeQuotes and EscapeString
v := "foo \x00\n\r\x1a\"'\\"
dbt.mustExec("INSERT INTO test VALUES (?)", v)

var out string
err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out)
if err != nil {
dbt.Fatalf("%s", err.Error())
}

if out != v {
dbt.Errorf("%q != %q", out, v)
}
}

dsns := []string{
dsn,
dsn + "&sql_mode=NO_BACKSLASH_ESCAPES",
}
for _, testdsn := range dsns {
runTests(t, testdsn, testData)
}
}
1 change: 1 addition & 0 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])

// server_status [2 bytes]
mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8

// warning count [2 bytes]
if !mc.strict {
Expand Down
Loading