update mssql drive to last working version 20180314172330-6a30f4e59a44 (#7306)
This commit is contained in:
parent
aeb8f7aad8
commit
1e46eedce7
2
go.mod
2
go.mod
|
@ -140,4 +140,4 @@ require (
|
||||||
xorm.io/core v0.6.3
|
xorm.io/core v0.6.3
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/denisenkom/go-mssqldb => github.com/denisenkom/go-mssqldb v0.0.0-20161128230840-e32ca5036449
|
replace github.com/denisenkom/go-mssqldb => github.com/denisenkom/go-mssqldb v0.0.0-20180314172330-6a30f4e59a44
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -60,8 +60,8 @@ github.com/cznic/strutil v0.0.0-20181122101858-275e90344537/go.mod h1:AHHPPPXTw0
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/denisenkom/go-mssqldb v0.0.0-20161128230840-e32ca5036449 h1:JpA+YMG4JLW8nzLmU05mTiuB0O17xHGxpWolEZ0zDuA=
|
github.com/denisenkom/go-mssqldb v0.0.0-20180314172330-6a30f4e59a44 h1:x0uHqLQTSEL9LKic8sWDt3ASkq07ve5ojIIUl5uF64M=
|
||||||
github.com/denisenkom/go-mssqldb v0.0.0-20161128230840-e32ca5036449/go.mod h1:xN/JuLBIz4bjkxNmByTiV1IbhfnYb6oo99phBn4Eqhc=
|
github.com/denisenkom/go-mssqldb v0.0.0-20180314172330-6a30f4e59a44/go.mod h1:xN/JuLBIz4bjkxNmByTiV1IbhfnYb6oo99phBn4Eqhc=
|
||||||
github.com/dgrijalva/jwt-go v0.0.0-20161101193935-9ed569b5d1ac h1:xrQJVwQCGqDvOO7/0+RyIq5J2M3Q4ZF7Ug/BMQtML1E=
|
github.com/dgrijalva/jwt-go v0.0.0-20161101193935-9ed569b5d1ac h1:xrQJVwQCGqDvOO7/0+RyIq5J2M3Q4ZF7Ug/BMQtML1E=
|
||||||
github.com/dgrijalva/jwt-go v0.0.0-20161101193935-9ed569b5d1ac/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
github.com/dgrijalva/jwt-go v0.0.0-20161101193935-9ed569b5d1ac/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||||
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 h1:aaQcKT9WumO6JEJcRyTqFVq4XUZiUcKR2/GI31TOcz8=
|
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 h1:aaQcKT9WumO6JEJcRyTqFVq4XUZiUcKR2/GI31TOcz8=
|
||||||
|
|
|
@ -1,79 +1,131 @@
|
||||||
# A pure Go MSSQL driver for Go's database/sql package
|
# A pure Go MSSQL driver for Go's database/sql package
|
||||||
|
|
||||||
|
[![GoDoc](https://godoc.org/github.com/denisenkom/go-mssqldb?status.svg)](http://godoc.org/github.com/denisenkom/go-mssqldb)
|
||||||
|
[![Build status](https://ci.appveyor.com/api/projects/status/jrln8cs62wj9i0a2?svg=true)](https://ci.appveyor.com/project/denisenkom/go-mssqldb)
|
||||||
|
[![codecov](https://codecov.io/gh/denisenkom/go-mssqldb/branch/master/graph/badge.svg)](https://codecov.io/gh/denisenkom/go-mssqldb)
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
go get github.com/denisenkom/go-mssqldb
|
Requires Go 1.8 or above.
|
||||||
|
|
||||||
## Tests
|
Install with `go get github.com/denisenkom/go-mssqldb` .
|
||||||
|
|
||||||
`go test` is used for testing. A running instance of MSSQL server is required.
|
## Connection Parameters and DSN
|
||||||
Environment variables are used to pass login information.
|
|
||||||
|
|
||||||
Example:
|
The recommended connection string uses a URL format:
|
||||||
|
`sqlserver://username:password@host/instance?param1=value¶m2=value`
|
||||||
|
Other supported formats are listed below.
|
||||||
|
|
||||||
env HOST=localhost SQLUSER=sa SQLPASSWORD=sa DATABASE=test go test
|
### Common parameters:
|
||||||
|
|
||||||
## Connection Parameters
|
* `user id` - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used.
|
||||||
|
* `password`
|
||||||
|
* `database`
|
||||||
|
* `connection timeout` - in seconds (default is 30)
|
||||||
|
* `dial timeout` - in seconds (default is 5)
|
||||||
|
* `encrypt`
|
||||||
|
* `disable` - Data send between client and server is not encrypted.
|
||||||
|
* `false` - Data sent between client and server is not encrypted beyond the login packet. (Default)
|
||||||
|
* `true` - Data sent between client and server is encrypted.
|
||||||
|
* `keepAlive` - in seconds; 0 to disable (default is 30)
|
||||||
|
* `app name` - The application name (default is go-mssqldb)
|
||||||
|
|
||||||
* "server" - host or host\instance (default localhost)
|
### Connection parameters for ODBC and ADO style connection strings:
|
||||||
* "port" - used only when there is no instance in server (default 1433)
|
|
||||||
* "failoverpartner" - host or host\instance (default is no partner).
|
* `server` - host or host\instance (default localhost)
|
||||||
* "failoverport" - used only when there is no instance in failoverpartner (default 1433)
|
* `port` - used only when there is no instance in server (default 1433)
|
||||||
* "user id" - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used.
|
|
||||||
* "password"
|
### Less common parameters:
|
||||||
* "database"
|
|
||||||
* "connection timeout" - in seconds (default is 30)
|
* `failoverpartner` - host or host\instance (default is no partner).
|
||||||
* "dial timeout" - in seconds (default is 5)
|
* `failoverport` - used only when there is no instance in failoverpartner (default 1433)
|
||||||
* "keepAlive" - in seconds; 0 to disable (default is 0)
|
* `packet size` - in bytes; 512 to 32767 (default is 4096)
|
||||||
* "log" - logging flags (default 0/no logging, 63 for full logging)
|
* Encrypted connections have a maximum packet size of 16383 bytes
|
||||||
|
* Further information on usage: https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
|
||||||
|
* `log` - logging flags (default 0/no logging, 63 for full logging)
|
||||||
* 1 log errors
|
* 1 log errors
|
||||||
* 2 log messages
|
* 2 log messages
|
||||||
* 4 log rows affected
|
* 4 log rows affected
|
||||||
* 8 trace sql statements
|
* 8 trace sql statements
|
||||||
* 16 log statement parameters
|
* 16 log statement parameters
|
||||||
* 32 log transaction begin/end
|
* 32 log transaction begin/end
|
||||||
* "encrypt"
|
* `TrustServerCertificate`
|
||||||
* disable - Data send between client and server is not encrypted.
|
|
||||||
* false - Data sent between client and server is not encrypted beyond the login packet. (Default)
|
|
||||||
* true - Data sent between client and server is encrypted.
|
|
||||||
* "TrustServerCertificate"
|
|
||||||
* false - Server certificate is checked. Default is false if encypt is specified.
|
* false - Server certificate is checked. Default is false if encypt is specified.
|
||||||
* true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
|
* true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
|
||||||
* "certificate" - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates.
|
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates.
|
||||||
* "hostNameInCertificate" - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
|
* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
|
||||||
* "ServerSPN" - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
|
* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
|
||||||
* "Workstation ID" - The workstation name (default is the host name)
|
* `Workstation ID` - The workstation name (default is the host name)
|
||||||
* "app name" - The application name (default is go-mssqldb)
|
* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener.
|
||||||
* "ApplicationIntent" - Can be given the value "ReadOnly" to initiate a read-only connection to an Availability Group listener.
|
|
||||||
|
|
||||||
Example:
|
### The connection string can be specified in one of three formats:
|
||||||
|
|
||||||
|
|
||||||
|
1. URL: with `sqlserver` scheme. username and password appears before the host. Any instance appears as
|
||||||
|
the first segment in the path. All other options are query parameters. Examples:
|
||||||
|
|
||||||
|
* `sqlserver://username:password@host/instance?param1=value¶m2=value`
|
||||||
|
* `sqlserver://username:password@host:port?param1=value¶m2=value`
|
||||||
|
* `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance.
|
||||||
|
* `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass.
|
||||||
|
* `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30"` // port 1234 on localhost.
|
||||||
|
* `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass"
|
||||||
|
|
||||||
|
A string of this format can be constructed using the `URL` type in the `net/url` package.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
db, err := sql.Open("mssql", "server=localhost;user id=sa")
|
query := url.Values{}
|
||||||
|
query.Add("connection timeout", "30")
|
||||||
|
|
||||||
|
u := &url.URL{
|
||||||
|
Scheme: "sqlserver",
|
||||||
|
User: url.UserPassword(username, password),
|
||||||
|
Host: fmt.Sprintf("%s:%d", hostname, port),
|
||||||
|
// Path: instance, // if connecting to an instance instead of a port
|
||||||
|
RawQuery: query.Encode(),
|
||||||
|
}
|
||||||
|
db, err := sql.Open("sqlserver", u.String())
|
||||||
|
```
|
||||||
|
|
||||||
|
2. ADO: `key=value` pairs separated by `;`. Values may not contain `;`, leading and trailing whitespace is ignored.
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
* `server=localhost\\SQLExpress;user id=sa;database=master;connection timeout=30`
|
||||||
|
* `server=localhost;user id=sa;database=master;connection timeout=30`
|
||||||
|
|
||||||
|
3. ODBC: Prefix with `odbc`, `key=value` pairs separated by `;`. Allow `;` by wrapping
|
||||||
|
values in `{}`. Examples:
|
||||||
|
|
||||||
|
* `odbc:server=localhost\\SQLExpress;user id=sa;database=master;connection timeout=30`
|
||||||
|
* `odbc:server=localhost;user id=sa;database=master;connection timeout=30`
|
||||||
|
* `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar"
|
||||||
|
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar"
|
||||||
|
* `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar "
|
||||||
|
* `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar"
|
||||||
|
* `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar"
|
||||||
|
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar"
|
||||||
|
* `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar"
|
||||||
|
|
||||||
|
## Executing Stored Procedures
|
||||||
|
|
||||||
|
To run a stored procedure, set the query text to the procedure name:
|
||||||
|
```go
|
||||||
|
var account = "abc"
|
||||||
|
_, err := db.ExecContext(ctx, "sp_RunMe",
|
||||||
|
sql.Named("ID", 123),
|
||||||
|
sql.Out{Dest{sql.Named("Account", &account)}
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Statement Parameters
|
## Statement Parameters
|
||||||
|
|
||||||
In the SQL statement text, literals may be replaced by a parameter that matches one of the following:
|
The `sqlserver` driver uses normal MS SQL Server syntax and expects parameters in
|
||||||
|
the sql query to be in the form of either `@Name` or `@p1` to `@pN` (ordinal position).
|
||||||
* ?
|
|
||||||
* ?nnn
|
|
||||||
* :nnn
|
|
||||||
* $nnn
|
|
||||||
|
|
||||||
where nnn represents an integer that specifies a 1-indexed positional parameter. Ex:
|
|
||||||
|
|
||||||
```go
|
```go
|
||||||
db.Query("SELECT * FROM t WHERE a = ?3, b = ?2, c = ?1", "x", "y", "z")
|
db.QueryContext(ctx, `select * from t where ID = @ID and Name = @p2;`, sql.Named("ID", 6), "Bob")
|
||||||
```
|
```
|
||||||
|
|
||||||
will expand to roughly
|
|
||||||
|
|
||||||
```sql
|
|
||||||
SELECT * FROM t WHERE a = 'z', b = 'y', c = 'x'
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
* Can be used with SQL Server 2005 or newer
|
* Can be used with SQL Server 2005 or newer
|
||||||
|
@ -87,6 +139,34 @@ SELECT * FROM t WHERE a = 'z', b = 'y', c = 'x'
|
||||||
* Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas.
|
* Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas.
|
||||||
* Supports query notifications
|
* Supports query notifications
|
||||||
|
|
||||||
|
## Tests
|
||||||
|
|
||||||
|
`go test` is used for testing. A running instance of MSSQL server is required.
|
||||||
|
Environment variables are used to pass login information.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
env SQLSERVER_DSN=sqlserver://user:pass@hostname/instance?database=test1 go test
|
||||||
|
|
||||||
|
## Deprecated
|
||||||
|
|
||||||
|
These features still exist in the driver, but they are are deprecated.
|
||||||
|
|
||||||
|
### Query Parameter Token Replace (driver "mssql")
|
||||||
|
|
||||||
|
If you use the driver name "mssql" (rather then "sqlserver" the SQL text
|
||||||
|
will be loosly parsed and an attempt to extract identifiers using one of
|
||||||
|
|
||||||
|
* ?
|
||||||
|
* ?nnn
|
||||||
|
* :nnn
|
||||||
|
* $nnn
|
||||||
|
|
||||||
|
will be used. This is not recommended with SQL Server.
|
||||||
|
There is at least one existing `won't fix` issue with the query parsing.
|
||||||
|
|
||||||
|
Use the native "@Name" parameters instead with the "sqlserver" driver name.
|
||||||
|
|
||||||
## Known Issues
|
## Known Issues
|
||||||
|
|
||||||
* SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled.
|
* SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled.
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
version: 1.0.{build}
|
||||||
|
|
||||||
|
os: Windows Server 2012 R2
|
||||||
|
|
||||||
|
clone_folder: c:\gopath\src\github.com\denisenkom\go-mssqldb
|
||||||
|
|
||||||
|
environment:
|
||||||
|
GOPATH: c:\gopath
|
||||||
|
HOST: localhost
|
||||||
|
SQLUSER: sa
|
||||||
|
SQLPASSWORD: Password12!
|
||||||
|
DATABASE: test
|
||||||
|
GOVERSION: 110
|
||||||
|
matrix:
|
||||||
|
- GOVERSION: 18
|
||||||
|
SQLINSTANCE: SQL2016
|
||||||
|
- GOVERSION: 110
|
||||||
|
SQLINSTANCE: SQL2016
|
||||||
|
- SQLINSTANCE: SQL2014
|
||||||
|
- SQLINSTANCE: SQL2012SP1
|
||||||
|
- SQLINSTANCE: SQL2008R2SP2
|
||||||
|
|
||||||
|
install:
|
||||||
|
- set GOROOT=c:\go%GOVERSION%
|
||||||
|
- set PATH=%GOPATH%\bin;%GOROOT%\bin;%PATH%
|
||||||
|
- go version
|
||||||
|
- go env
|
||||||
|
|
||||||
|
build_script:
|
||||||
|
- go build
|
||||||
|
|
||||||
|
before_test:
|
||||||
|
# setup SQL Server
|
||||||
|
- ps: |
|
||||||
|
$instanceName = $env:SQLINSTANCE
|
||||||
|
Start-Service "MSSQL`$$instanceName"
|
||||||
|
Start-Service "SQLBrowser"
|
||||||
|
- sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;"
|
||||||
|
- sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version"
|
||||||
|
- pip install codecov
|
||||||
|
|
||||||
|
|
||||||
|
test_script:
|
||||||
|
- go test -race -coverprofile=coverage.txt -covermode=atomic
|
||||||
|
- codecov -f coverage.txt
|
|
@ -2,12 +2,14 @@ package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type packetType uint8
|
||||||
|
|
||||||
type header struct {
|
type header struct {
|
||||||
PacketType uint8
|
PacketType packetType
|
||||||
Status uint8
|
Status uint8
|
||||||
Size uint16
|
Size uint16
|
||||||
Spid uint16
|
Spid uint16
|
||||||
|
@ -15,55 +17,84 @@ type header struct {
|
||||||
Pad uint8
|
Pad uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tdsBuffer reads and writes TDS packets of data to the transport.
|
||||||
|
// The write and read buffers are separate to make sending attn signals
|
||||||
|
// possible without locks. Currently attn signals are only sent during
|
||||||
|
// reads, not writes.
|
||||||
type tdsBuffer struct {
|
type tdsBuffer struct {
|
||||||
buf []byte
|
|
||||||
pos uint16
|
|
||||||
transport io.ReadWriteCloser
|
transport io.ReadWriteCloser
|
||||||
size uint16
|
|
||||||
|
packetSize int
|
||||||
|
|
||||||
|
// Write fields.
|
||||||
|
wbuf []byte
|
||||||
|
wpos int
|
||||||
|
wPacketSeq byte
|
||||||
|
wPacketType packetType
|
||||||
|
|
||||||
|
// Read fields.
|
||||||
|
rbuf []byte
|
||||||
|
rpos int
|
||||||
|
rsize int
|
||||||
final bool
|
final bool
|
||||||
packet_type uint8
|
rPacketType packetType
|
||||||
|
|
||||||
|
// afterFirst is assigned to right after tdsBuffer is created and
|
||||||
|
// before the first use. It is executed after the first packet is
|
||||||
|
// written and then removed.
|
||||||
afterFirst func()
|
afterFirst func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTdsBuffer(bufsize int, transport io.ReadWriteCloser) *tdsBuffer {
|
func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
|
||||||
buf := make([]byte, bufsize)
|
return &tdsBuffer{
|
||||||
w := new(tdsBuffer)
|
packetSize: int(bufsize),
|
||||||
w.buf = buf
|
wbuf: make([]byte, 1<<16),
|
||||||
w.pos = 8
|
rbuf: make([]byte, 1<<16),
|
||||||
w.transport = transport
|
rpos: 8,
|
||||||
w.size = 0
|
transport: transport,
|
||||||
return w
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *tdsBuffer) ResizeBuffer(packetSize int) {
|
||||||
|
rw.packetSize = packetSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *tdsBuffer) PackageSize() int {
|
||||||
|
return w.packetSize
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *tdsBuffer) flush() (err error) {
|
func (w *tdsBuffer) flush() (err error) {
|
||||||
// writing packet size
|
// Write packet size.
|
||||||
binary.BigEndian.PutUint16(w.buf[2:], w.pos)
|
w.wbuf[0] = byte(w.wPacketType)
|
||||||
|
binary.BigEndian.PutUint16(w.wbuf[2:], uint16(w.wpos))
|
||||||
|
w.wbuf[6] = w.wPacketSeq
|
||||||
|
|
||||||
// writing packet into underlying transport
|
// Write packet into underlying transport.
|
||||||
if _, err = w.transport.Write(w.buf[:w.pos]); err != nil {
|
if _, err = w.transport.Write(w.wbuf[:w.wpos]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// It is possible to create a whole new buffer after a flush.
|
||||||
|
// Useful for debugging. Normally reuse the buffer.
|
||||||
|
// w.wbuf = make([]byte, 1<<16)
|
||||||
|
|
||||||
// execute afterFirst hook if it is set
|
// Execute afterFirst hook if it is set.
|
||||||
if w.afterFirst != nil {
|
if w.afterFirst != nil {
|
||||||
w.afterFirst()
|
w.afterFirst()
|
||||||
w.afterFirst = nil
|
w.afterFirst = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
w.pos = 8
|
w.wpos = 8
|
||||||
// packet number
|
w.wPacketSeq++
|
||||||
w.buf[6] += 1
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *tdsBuffer) Write(p []byte) (total int, err error) {
|
func (w *tdsBuffer) Write(p []byte) (total int, err error) {
|
||||||
total = 0
|
|
||||||
for {
|
for {
|
||||||
copied := copy(w.buf[w.pos:], p)
|
copied := copy(w.wbuf[w.wpos:w.packetSize], p)
|
||||||
w.pos += uint16(copied)
|
w.wpos += copied
|
||||||
total += copied
|
total += copied
|
||||||
if copied == len(p) {
|
if copied == len(p) {
|
||||||
break
|
return
|
||||||
}
|
}
|
||||||
if err = w.flush(); err != nil {
|
if err = w.flush(); err != nil {
|
||||||
return
|
return
|
||||||
|
@ -74,66 +105,64 @@ func (w *tdsBuffer) Write(p []byte) (total int, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *tdsBuffer) WriteByte(b byte) error {
|
func (w *tdsBuffer) WriteByte(b byte) error {
|
||||||
if int(w.pos) == len(w.buf) {
|
if int(w.wpos) == len(w.wbuf) {
|
||||||
if err := w.flush(); err != nil {
|
if err := w.flush(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.buf[w.pos] = b
|
w.wbuf[w.wpos] = b
|
||||||
w.pos += 1
|
w.wpos += 1
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *tdsBuffer) BeginPacket(packet_type byte) {
|
func (w *tdsBuffer) BeginPacket(packetType packetType) {
|
||||||
w.buf[0] = packet_type
|
w.wbuf[1] = 0 // Packet is incomplete. This byte is set again in FinishPacket.
|
||||||
w.buf[1] = 0 // packet is incomplete
|
w.wpos = 8
|
||||||
w.buf[4] = 0 // spid
|
w.wPacketSeq = 1
|
||||||
w.buf[5] = 0
|
w.wPacketType = packetType
|
||||||
w.buf[6] = 1 // packet id
|
|
||||||
w.buf[7] = 0 // window
|
|
||||||
w.pos = 8
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *tdsBuffer) FinishPacket() error {
|
func (w *tdsBuffer) FinishPacket() error {
|
||||||
w.buf[1] = 1 // this is last packet
|
w.wbuf[1] = 1 // Mark this as the last packet in the message.
|
||||||
return w.flush()
|
return w.flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var headerSize = binary.Size(header{})
|
||||||
|
|
||||||
func (r *tdsBuffer) readNextPacket() error {
|
func (r *tdsBuffer) readNextPacket() error {
|
||||||
header := header{}
|
h := header{}
|
||||||
var err error
|
var err error
|
||||||
err = binary.Read(r.transport, binary.BigEndian, &header)
|
err = binary.Read(r.transport, binary.BigEndian, &h)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
offset := uint16(binary.Size(header))
|
if int(h.Size) > len(r.rbuf) {
|
||||||
if int(header.Size) > len(r.buf) {
|
|
||||||
return errors.New("Invalid packet size, it is longer than buffer size")
|
return errors.New("Invalid packet size, it is longer than buffer size")
|
||||||
}
|
}
|
||||||
if int(offset) > int(header.Size) {
|
if headerSize > int(h.Size) {
|
||||||
return errors.New("Invalid packet size, it is shorter than header size")
|
return errors.New("Invalid packet size, it is shorter than header size")
|
||||||
}
|
}
|
||||||
_, err = io.ReadFull(r.transport, r.buf[offset:header.Size])
|
_, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
r.pos = offset
|
r.rpos = headerSize
|
||||||
r.size = header.Size
|
r.rsize = int(h.Size)
|
||||||
r.final = header.Status != 0
|
r.final = h.Status != 0
|
||||||
r.packet_type = header.PacketType
|
r.rPacketType = h.PacketType
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *tdsBuffer) BeginRead() (uint8, error) {
|
func (r *tdsBuffer) BeginRead() (packetType, error) {
|
||||||
err := r.readNextPacket()
|
err := r.readNextPacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return r.packet_type, nil
|
return r.rPacketType, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *tdsBuffer) ReadByte() (res byte, err error) {
|
func (r *tdsBuffer) ReadByte() (res byte, err error) {
|
||||||
if r.pos == r.size {
|
if r.rpos == r.rsize {
|
||||||
if r.final {
|
if r.final {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
@ -142,8 +171,8 @@ func (r *tdsBuffer) ReadByte() (res byte, err error) {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
res = r.buf[r.pos]
|
res = r.rbuf[r.rpos]
|
||||||
r.pos++
|
r.rpos++
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,7 +236,7 @@ func (r *tdsBuffer) readUcs2(numchars int) string {
|
||||||
func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
|
func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
|
||||||
copied = 0
|
copied = 0
|
||||||
err = nil
|
err = nil
|
||||||
if r.pos == r.size {
|
if r.rpos == r.rsize {
|
||||||
if r.final {
|
if r.final {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
@ -216,7 +245,7 @@ func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
copied = copy(buf, r.buf[r.pos:r.size])
|
copied = copy(buf, r.rbuf[r.rpos:r.rsize])
|
||||||
r.pos += uint16(copied)
|
r.rpos += copied
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,616 @@
|
||||||
|
package mssql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bulk struct {
|
||||||
|
cn *Conn
|
||||||
|
metadata []columnStruct
|
||||||
|
bulkColumns []columnStruct
|
||||||
|
columnsName []string
|
||||||
|
tablename string
|
||||||
|
numRows int
|
||||||
|
|
||||||
|
headerSent bool
|
||||||
|
Options BulkOptions
|
||||||
|
Debug bool
|
||||||
|
}
|
||||||
|
type BulkOptions struct {
|
||||||
|
CheckConstraints bool
|
||||||
|
FireTriggers bool
|
||||||
|
KeepNulls bool
|
||||||
|
KilobytesPerBatch int
|
||||||
|
RowsPerBatch int
|
||||||
|
Order []string
|
||||||
|
Tablock bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type DataValue interface{}
|
||||||
|
|
||||||
|
func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
|
||||||
|
b := Bulk{cn: cn, tablename: table, headerSent: false, columnsName: columns}
|
||||||
|
b.Debug = false
|
||||||
|
return &b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) sendBulkCommand() (err error) {
|
||||||
|
//get table columns info
|
||||||
|
err = b.getMetadata()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//match the columns
|
||||||
|
for _, colname := range b.columnsName {
|
||||||
|
var bulkCol *columnStruct
|
||||||
|
|
||||||
|
for _, m := range b.metadata {
|
||||||
|
if m.ColName == colname {
|
||||||
|
bulkCol = &m
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if bulkCol != nil {
|
||||||
|
|
||||||
|
if bulkCol.ti.TypeId == typeUdt {
|
||||||
|
//send udt as binary
|
||||||
|
bulkCol.ti.TypeId = typeBigVarBin
|
||||||
|
}
|
||||||
|
b.bulkColumns = append(b.bulkColumns, *bulkCol)
|
||||||
|
b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//create the bulk command
|
||||||
|
|
||||||
|
//columns definitions
|
||||||
|
var col_defs bytes.Buffer
|
||||||
|
for i, col := range b.bulkColumns {
|
||||||
|
if i != 0 {
|
||||||
|
col_defs.WriteString(", ")
|
||||||
|
}
|
||||||
|
col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti))
|
||||||
|
}
|
||||||
|
|
||||||
|
//options
|
||||||
|
var with_opts []string
|
||||||
|
|
||||||
|
if b.Options.CheckConstraints {
|
||||||
|
with_opts = append(with_opts, "CHECK_CONSTRAINTS")
|
||||||
|
}
|
||||||
|
if b.Options.FireTriggers {
|
||||||
|
with_opts = append(with_opts, "FIRE_TRIGGERS")
|
||||||
|
}
|
||||||
|
if b.Options.KeepNulls {
|
||||||
|
with_opts = append(with_opts, "KEEP_NULLS")
|
||||||
|
}
|
||||||
|
if b.Options.KilobytesPerBatch > 0 {
|
||||||
|
with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch))
|
||||||
|
}
|
||||||
|
if b.Options.RowsPerBatch > 0 {
|
||||||
|
with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch))
|
||||||
|
}
|
||||||
|
if len(b.Options.Order) > 0 {
|
||||||
|
with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ",")))
|
||||||
|
}
|
||||||
|
if b.Options.Tablock {
|
||||||
|
with_opts = append(with_opts, "TABLOCK")
|
||||||
|
}
|
||||||
|
var with_part string
|
||||||
|
if len(with_opts) > 0 {
|
||||||
|
with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part)
|
||||||
|
|
||||||
|
stmt, err := b.cn.Prepare(query)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Prepare failed: %s", err.Error())
|
||||||
|
}
|
||||||
|
b.dlogf(query)
|
||||||
|
|
||||||
|
_, err = stmt.Exec(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.headerSent = true
|
||||||
|
|
||||||
|
var buf = b.cn.sess.buf
|
||||||
|
buf.BeginPacket(packBulkLoadBCP)
|
||||||
|
|
||||||
|
// send the columns metadata
|
||||||
|
columnMetadata := b.createColMetadata()
|
||||||
|
_, err = buf.Write(columnMetadata)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRow immediately writes the row to the destination table.
|
||||||
|
// The arguments are the row values in the order they were specified.
|
||||||
|
func (b *Bulk) AddRow(row []interface{}) (err error) {
|
||||||
|
if !b.headerSent {
|
||||||
|
err = b.sendBulkCommand()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(row) != len(b.bulkColumns) {
|
||||||
|
return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d",
|
||||||
|
len(row), len(b.bulkColumns))
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes, err := b.makeRowData(row)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = b.cn.sess.buf.Write(bytes)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b.numRows = b.numRows + 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
buf.WriteByte(byte(tokenRow))
|
||||||
|
|
||||||
|
var logcol bytes.Buffer
|
||||||
|
for i, col := range b.bulkColumns {
|
||||||
|
|
||||||
|
if b.Debug {
|
||||||
|
logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i]))
|
||||||
|
}
|
||||||
|
param, err := b.makeParam(row[i], col)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("bulkcopy: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if col.ti.Writer == nil {
|
||||||
|
return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x",
|
||||||
|
col.ColName, col.ti.TypeId)
|
||||||
|
}
|
||||||
|
err = col.ti.Writer(buf, param.ti, param.buffer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("bulkcopy: %s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.dlogf("row[%d] %s\n", b.numRows, logcol.String())
|
||||||
|
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) Done() (rowcount int64, err error) {
|
||||||
|
if b.headerSent == false {
|
||||||
|
//no rows had been sent
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
var buf = b.cn.sess.buf
|
||||||
|
buf.WriteByte(byte(tokenDone))
|
||||||
|
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint16(doneFinal))
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint16(0)) // curcmd
|
||||||
|
|
||||||
|
if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint64(0)) //rowcount 0
|
||||||
|
} else {
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint32(0)) //rowcount 0
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.FinishPacket()
|
||||||
|
|
||||||
|
tokchan := make(chan tokenStruct, 5)
|
||||||
|
go processResponse(context.Background(), b.cn.sess, tokchan, nil)
|
||||||
|
|
||||||
|
var rowCount int64
|
||||||
|
for token := range tokchan {
|
||||||
|
switch token := token.(type) {
|
||||||
|
case doneStruct:
|
||||||
|
if token.Status&doneCount != 0 {
|
||||||
|
rowCount = int64(token.RowCount)
|
||||||
|
}
|
||||||
|
if token.isError() {
|
||||||
|
return 0, token.getError()
|
||||||
|
}
|
||||||
|
case error:
|
||||||
|
return 0, b.cn.checkBadConn(token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rowCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) createColMetadata() []byte {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
buf.WriteByte(byte(tokenColMetadata)) // token
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count
|
||||||
|
|
||||||
|
for i, col := range b.bulkColumns {
|
||||||
|
|
||||||
|
if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) // usertype, always 0?
|
||||||
|
} else {
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint16(col.UserType))
|
||||||
|
}
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint16(col.Flags))
|
||||||
|
|
||||||
|
writeTypeInfo(buf, &b.bulkColumns[i].ti)
|
||||||
|
|
||||||
|
if col.ti.TypeId == typeNText ||
|
||||||
|
col.ti.TypeId == typeText ||
|
||||||
|
col.ti.TypeId == typeImage {
|
||||||
|
|
||||||
|
tablename_ucs2 := str2ucs2(b.tablename)
|
||||||
|
binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2))
|
||||||
|
buf.Write(tablename_ucs2)
|
||||||
|
}
|
||||||
|
colname_ucs2 := str2ucs2(col.ColName)
|
||||||
|
buf.WriteByte(uint8(len(colname_ucs2) / 2))
|
||||||
|
buf.Write(colname_ucs2)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) getMetadata() (err error) {
|
||||||
|
stmt, err := b.cn.Prepare("SET FMTONLY ON")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = stmt.Exec(nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//get columns info
|
||||||
|
stmt, err = b.cn.Prepare(fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stmt2 := stmt.(*Stmt)
|
||||||
|
cols, err := stmt2.QueryMeta()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get columns info failed: %v", err.Error())
|
||||||
|
}
|
||||||
|
b.metadata = cols
|
||||||
|
|
||||||
|
if b.Debug {
|
||||||
|
for _, col := range b.metadata {
|
||||||
|
b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n",
|
||||||
|
col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec,
|
||||||
|
col.Flags, col.ti.Collation.LcidAndFlags)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMeta is almost the same as mssql.Stmt.Query, but returns all the columns info.
|
||||||
|
func (s *Stmt) QueryMeta() (cols []columnStruct, err error) {
|
||||||
|
if err = s.sendQuery(nil); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokchan := make(chan tokenStruct, 5)
|
||||||
|
go processResponse(context.Background(), s.c.sess, tokchan, s.c.outs)
|
||||||
|
s.c.clearOuts()
|
||||||
|
loop:
|
||||||
|
for tok := range tokchan {
|
||||||
|
switch token := tok.(type) {
|
||||||
|
case doneStruct:
|
||||||
|
break loop
|
||||||
|
case []columnStruct:
|
||||||
|
cols = token
|
||||||
|
break loop
|
||||||
|
case error:
|
||||||
|
return nil, s.c.checkBadConn(token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cols, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) {
|
||||||
|
res.ti.Size = col.ti.Size
|
||||||
|
res.ti.TypeId = col.ti.TypeId
|
||||||
|
|
||||||
|
if val == nil {
|
||||||
|
res.ti.Size = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch col.ti.TypeId {
|
||||||
|
|
||||||
|
case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN:
|
||||||
|
var intvalue int64
|
||||||
|
|
||||||
|
switch val := val.(type) {
|
||||||
|
case int:
|
||||||
|
intvalue = int64(val)
|
||||||
|
case int32:
|
||||||
|
intvalue = int64(val)
|
||||||
|
case int64:
|
||||||
|
intvalue = val
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: invalid type for int column")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res.buffer = make([]byte, res.ti.Size)
|
||||||
|
if col.ti.Size == 1 {
|
||||||
|
res.buffer[0] = byte(intvalue)
|
||||||
|
} else if col.ti.Size == 2 {
|
||||||
|
binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue))
|
||||||
|
} else if col.ti.Size == 4 {
|
||||||
|
binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue))
|
||||||
|
} else if col.ti.Size == 8 {
|
||||||
|
binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue))
|
||||||
|
}
|
||||||
|
case typeFlt4, typeFlt8, typeFltN:
|
||||||
|
var floatvalue float64
|
||||||
|
|
||||||
|
switch val := val.(type) {
|
||||||
|
case float32:
|
||||||
|
floatvalue = float64(val)
|
||||||
|
case float64:
|
||||||
|
floatvalue = val
|
||||||
|
case int:
|
||||||
|
floatvalue = float64(val)
|
||||||
|
case int64:
|
||||||
|
floatvalue = float64(val)
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: invalid type for float column: %s", val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if col.ti.Size == 4 {
|
||||||
|
res.buffer = make([]byte, 4)
|
||||||
|
binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue)))
|
||||||
|
} else if col.ti.Size == 8 {
|
||||||
|
res.buffer = make([]byte, 8)
|
||||||
|
binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue))
|
||||||
|
}
|
||||||
|
case typeNVarChar, typeNText, typeNChar:
|
||||||
|
|
||||||
|
switch val := val.(type) {
|
||||||
|
case string:
|
||||||
|
res.buffer = str2ucs2(val)
|
||||||
|
case []byte:
|
||||||
|
res.buffer = val
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: invalid type for nvarchar column: %s", val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.ti.Size = len(res.buffer)
|
||||||
|
|
||||||
|
case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar:
|
||||||
|
switch val := val.(type) {
|
||||||
|
case string:
|
||||||
|
res.buffer = []byte(val)
|
||||||
|
case []byte:
|
||||||
|
res.buffer = val
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: invalid type for varchar column: %s", val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.ti.Size = len(res.buffer)
|
||||||
|
|
||||||
|
case typeBit, typeBitN:
|
||||||
|
if reflect.TypeOf(val).Kind() != reflect.Bool {
|
||||||
|
err = fmt.Errorf("mssql: invalid type for bit column: %s", val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.ti.TypeId = typeBitN
|
||||||
|
res.ti.Size = 1
|
||||||
|
res.buffer = make([]byte, 1)
|
||||||
|
if val.(bool) {
|
||||||
|
res.buffer[0] = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
case typeDateTime2N, typeDateTimeOffsetN:
|
||||||
|
switch val := val.(type) {
|
||||||
|
case time.Time:
|
||||||
|
days, ns := dateTime2(val)
|
||||||
|
ns /= int64(math.Pow10(int(col.ti.Scale)*-1) * 1000000000)
|
||||||
|
|
||||||
|
var data = make([]byte, 5)
|
||||||
|
|
||||||
|
data[0] = byte(ns)
|
||||||
|
data[1] = byte(ns >> 8)
|
||||||
|
data[2] = byte(ns >> 16)
|
||||||
|
data[3] = byte(ns >> 24)
|
||||||
|
data[4] = byte(ns >> 32)
|
||||||
|
|
||||||
|
if col.ti.Scale <= 2 {
|
||||||
|
res.ti.Size = 6
|
||||||
|
} else if col.ti.Scale <= 4 {
|
||||||
|
res.ti.Size = 7
|
||||||
|
} else {
|
||||||
|
res.ti.Size = 8
|
||||||
|
}
|
||||||
|
var buf []byte
|
||||||
|
buf = make([]byte, res.ti.Size)
|
||||||
|
copy(buf, data[0:res.ti.Size-3])
|
||||||
|
|
||||||
|
buf[res.ti.Size-3] = byte(days)
|
||||||
|
buf[res.ti.Size-2] = byte(days >> 8)
|
||||||
|
buf[res.ti.Size-1] = byte(days >> 16)
|
||||||
|
|
||||||
|
if col.ti.TypeId == typeDateTimeOffsetN {
|
||||||
|
_, offset := val.Zone()
|
||||||
|
var offsetMinute = uint16(offset / 60)
|
||||||
|
buf = append(buf, byte(offsetMinute))
|
||||||
|
buf = append(buf, byte(offsetMinute>>8))
|
||||||
|
res.ti.Size = res.ti.Size + 2
|
||||||
|
}
|
||||||
|
|
||||||
|
res.buffer = buf
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: invalid type for datetime2 column: %s", val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case typeDateN:
|
||||||
|
switch val := val.(type) {
|
||||||
|
case time.Time:
|
||||||
|
days, _ := dateTime2(val)
|
||||||
|
|
||||||
|
res.ti.Size = 3
|
||||||
|
res.buffer = make([]byte, 3)
|
||||||
|
res.buffer[0] = byte(days)
|
||||||
|
res.buffer[1] = byte(days >> 8)
|
||||||
|
res.buffer[2] = byte(days >> 16)
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: invalid type for date column: %s", val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case typeDateTime, typeDateTimeN, typeDateTim4:
|
||||||
|
switch val := val.(type) {
|
||||||
|
case time.Time:
|
||||||
|
if col.ti.Size == 4 {
|
||||||
|
res.ti.Size = 4
|
||||||
|
res.buffer = make([]byte, 4)
|
||||||
|
|
||||||
|
ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
dur := val.Sub(ref)
|
||||||
|
days := dur / (24 * time.Hour)
|
||||||
|
if days < 0 {
|
||||||
|
err = fmt.Errorf("mssql: Date %s is out of range", val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mins := val.Hour()*60 + val.Minute()
|
||||||
|
|
||||||
|
binary.LittleEndian.PutUint16(res.buffer[0:2], uint16(days))
|
||||||
|
binary.LittleEndian.PutUint16(res.buffer[2:4], uint16(mins))
|
||||||
|
} else if col.ti.Size == 8 {
|
||||||
|
res.ti.Size = 8
|
||||||
|
res.buffer = make([]byte, 8)
|
||||||
|
|
||||||
|
days := divFloor(val.Unix(), 24*60*60)
|
||||||
|
//25567 - number of days since Jan 1 1900 UTC to Jan 1 1970
|
||||||
|
days = days + 25567
|
||||||
|
tm := (val.Hour()*60*60+val.Minute()*60+val.Second())*300 + int(val.Nanosecond()/10000000*3)
|
||||||
|
|
||||||
|
binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
|
||||||
|
binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("mssql: invalid size of column")
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: invalid type for datetime column: %s", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// case typeMoney, typeMoney4, typeMoneyN:
|
||||||
|
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
|
||||||
|
var value float64
|
||||||
|
switch v := val.(type) {
|
||||||
|
case int:
|
||||||
|
value = float64(v)
|
||||||
|
case int8:
|
||||||
|
value = float64(v)
|
||||||
|
case int16:
|
||||||
|
value = float64(v)
|
||||||
|
case int32:
|
||||||
|
value = float64(v)
|
||||||
|
case int64:
|
||||||
|
value = float64(v)
|
||||||
|
case float32:
|
||||||
|
value = float64(v)
|
||||||
|
case float64:
|
||||||
|
value = v
|
||||||
|
case string:
|
||||||
|
if value, err = strconv.ParseFloat(v, 64); err != nil {
|
||||||
|
return res, fmt.Errorf("bulk: unable to convert string to float: %v", err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return res, fmt.Errorf("unknown value for decimal: %#v", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
perc := col.ti.Prec
|
||||||
|
scale := col.ti.Scale
|
||||||
|
var dec Decimal
|
||||||
|
dec, err = Float64ToDecimalScale(value, scale)
|
||||||
|
if err != nil {
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
dec.prec = perc
|
||||||
|
|
||||||
|
var length byte
|
||||||
|
switch {
|
||||||
|
case perc <= 9:
|
||||||
|
length = 4
|
||||||
|
case perc <= 19:
|
||||||
|
length = 8
|
||||||
|
case perc <= 28:
|
||||||
|
length = 12
|
||||||
|
default:
|
||||||
|
length = 16
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, length+1)
|
||||||
|
// first byte length written by typeInfo.writer
|
||||||
|
res.ti.Size = int(length) + 1
|
||||||
|
// second byte sign
|
||||||
|
if value < 0 {
|
||||||
|
buf[0] = 0
|
||||||
|
} else {
|
||||||
|
buf[0] = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
ub := dec.UnscaledBytes()
|
||||||
|
l := len(ub)
|
||||||
|
if l > int(length) {
|
||||||
|
err = fmt.Errorf("decimal out of range: %s", dec)
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
// reverse the bytes
|
||||||
|
for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 {
|
||||||
|
buf[i] = ub[j]
|
||||||
|
}
|
||||||
|
res.buffer = buf
|
||||||
|
case typeBigVarBin:
|
||||||
|
switch val := val.(type) {
|
||||||
|
case []byte:
|
||||||
|
res.ti.Size = len(val)
|
||||||
|
res.buffer = val
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: invalid type for Binary column: %s", val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case typeGuid:
|
||||||
|
switch val := val.(type) {
|
||||||
|
case []byte:
|
||||||
|
res.ti.Size = len(val)
|
||||||
|
res.buffer = val
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: invalid type for Guid column: %s", val)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Bulk) dlogf(format string, v ...interface{}) {
|
||||||
|
if b.Debug {
|
||||||
|
b.cn.sess.log.Printf(format, v...)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,93 @@
|
||||||
|
package mssql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type copyin struct {
|
||||||
|
cn *Conn
|
||||||
|
bulkcopy *Bulk
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type serializableBulkConfig struct {
|
||||||
|
TableName string
|
||||||
|
ColumnsName []string
|
||||||
|
Options BulkOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) OpenConnection(dsn string) (*Conn, error) {
|
||||||
|
return d.open(context.Background(), dsn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) prepareCopyIn(query string) (_ driver.Stmt, err error) {
|
||||||
|
config_json := query[11:]
|
||||||
|
|
||||||
|
bulkconfig := serializableBulkConfig{}
|
||||||
|
err = json.Unmarshal([]byte(config_json), &bulkconfig)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bulkcopy := c.CreateBulk(bulkconfig.TableName, bulkconfig.ColumnsName)
|
||||||
|
bulkcopy.Options = bulkconfig.Options
|
||||||
|
|
||||||
|
ci := ©in{
|
||||||
|
cn: c,
|
||||||
|
bulkcopy: bulkcopy,
|
||||||
|
}
|
||||||
|
|
||||||
|
return ci, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CopyIn(table string, options BulkOptions, columns ...string) string {
|
||||||
|
bulkconfig := &serializableBulkConfig{TableName: table, Options: options, ColumnsName: columns}
|
||||||
|
|
||||||
|
config_json, err := json.Marshal(bulkconfig)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := "INSERTBULK " + string(config_json)
|
||||||
|
|
||||||
|
return stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ci *copyin) NumInput() int {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
|
||||||
|
return nil, errors.New("ErrNotSupported")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
|
||||||
|
if ci.closed {
|
||||||
|
return nil, errors.New("errCopyInClosed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(v) == 0 {
|
||||||
|
rowCount, err := ci.bulkcopy.Done()
|
||||||
|
ci.closed = true
|
||||||
|
return driver.RowsAffected(rowCount), err
|
||||||
|
}
|
||||||
|
|
||||||
|
t := make([]interface{}, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
t[i] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ci.bulkcopy.AddRow(t)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return driver.RowsAffected(0), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ci *copyin) Close() (err error) {
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,39 +0,0 @@
|
||||||
package mssql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// http://msdn.microsoft.com/en-us/library/dd340437.aspx
|
|
||||||
|
|
||||||
type collation struct {
|
|
||||||
lcidAndFlags uint32
|
|
||||||
sortId uint8
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c collation) getLcid() uint32 {
|
|
||||||
return c.lcidAndFlags & 0x000fffff
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c collation) getFlags() uint32 {
|
|
||||||
return (c.lcidAndFlags & 0x0ff00000) >> 20
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c collation) getVersion() uint32 {
|
|
||||||
return (c.lcidAndFlags & 0xf0000000) >> 28
|
|
||||||
}
|
|
||||||
|
|
||||||
func readCollation(r *tdsBuffer) (res collation) {
|
|
||||||
res.lcidAndFlags = r.uint32()
|
|
||||||
res.sortId = r.byte()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeCollation(w io.Writer, col collation) (err error) {
|
|
||||||
if err = binary.Write(w, binary.LittleEndian, col.lcidAndFlags); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = binary.Write(w, binary.LittleEndian, col.sortId)
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -32,7 +32,13 @@ func (d Decimal) ToFloat64() float64 {
|
||||||
return val
|
return val
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const autoScale = 100
|
||||||
|
|
||||||
func Float64ToDecimal(f float64) (Decimal, error) {
|
func Float64ToDecimal(f float64) (Decimal, error) {
|
||||||
|
return Float64ToDecimalScale(f, autoScale)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Float64ToDecimalScale(f float64, scale uint8) (Decimal, error) {
|
||||||
var dec Decimal
|
var dec Decimal
|
||||||
if math.IsNaN(f) {
|
if math.IsNaN(f) {
|
||||||
return dec, errors.New("NaN")
|
return dec, errors.New("NaN")
|
||||||
|
@ -49,10 +55,10 @@ func Float64ToDecimal(f float64) (Decimal, error) {
|
||||||
}
|
}
|
||||||
dec.prec = 20
|
dec.prec = 20
|
||||||
var integer float64
|
var integer float64
|
||||||
for dec.scale = 0; dec.scale <= 20; dec.scale++ {
|
for dec.scale = 0; dec.scale <= scale; dec.scale++ {
|
||||||
integer = f * scaletblflt64[dec.scale]
|
integer = f * scaletblflt64[dec.scale]
|
||||||
_, frac := math.Modf(integer)
|
_, frac := math.Modf(integer)
|
||||||
if frac == 0 {
|
if frac == 0 && scale == autoScale {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -73,7 +79,7 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d Decimal) Bytes() []byte {
|
func (d Decimal) BigInt() big.Int {
|
||||||
bytes := make([]byte, 16)
|
bytes := make([]byte, 16)
|
||||||
binary.BigEndian.PutUint32(bytes[0:4], d.integer[3])
|
binary.BigEndian.PutUint32(bytes[0:4], d.integer[3])
|
||||||
binary.BigEndian.PutUint32(bytes[4:8], d.integer[2])
|
binary.BigEndian.PutUint32(bytes[4:8], d.integer[2])
|
||||||
|
@ -84,9 +90,19 @@ func (d Decimal) Bytes() []byte {
|
||||||
if !d.positive {
|
if !d.positive {
|
||||||
x.Neg(&x)
|
x.Neg(&x)
|
||||||
}
|
}
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Decimal) Bytes() []byte {
|
||||||
|
x := d.BigInt()
|
||||||
return scaleBytes(x.String(), d.scale)
|
return scaleBytes(x.String(), d.scale)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d Decimal) UnscaledBytes() []byte {
|
||||||
|
x := d.BigInt()
|
||||||
|
return x.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
func scaleBytes(s string, scale uint8) []byte {
|
func scaleBytes(s string, scale uint8) []byte {
|
||||||
z := make([]byte, 0, len(s)+1)
|
z := make([]byte, 0, len(s)+1)
|
||||||
if s[0] == '-' || s[0] == '+' {
|
if s[0] == '-' || s[0] == '+' {
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
// package mssql implements the TDS protocol used to connect to MS SQL Server (sqlserver)
|
||||||
|
// database servers.
|
||||||
|
//
|
||||||
|
// This package registers two drivers:
|
||||||
|
// sqlserver: uses native "@" parameter placeholder names and does no pre-processing.
|
||||||
|
// mssql: expects identifiers to be prefixed with ":" and pre-processes queries.
|
||||||
|
//
|
||||||
|
// If the ordinal position is used for query parameters, identifiers will be named
|
||||||
|
// "@p1", "@p2", ... "@pN".
|
||||||
|
//
|
||||||
|
// Please refer to the README for the format of the DSN.
|
||||||
|
package mssql
|
|
@ -1,14 +1,14 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
type charsetMap struct {
|
type charsetMap struct {
|
||||||
sb [256]rune // single byte runes, -1 for a double byte character lead byte
|
sb [256]rune // single byte runes, -1 for a double byte character lead byte
|
||||||
db map[int]rune // double byte runes
|
db map[int]rune // double byte runes
|
||||||
}
|
}
|
||||||
|
|
||||||
func collation2charset(col collation) *charsetMap {
|
func collation2charset(col Collation) *charsetMap {
|
||||||
// http://msdn.microsoft.com/en-us/library/ms144250.aspx
|
// http://msdn.microsoft.com/en-us/library/ms144250.aspx
|
||||||
// http://msdn.microsoft.com/en-us/library/ms144250(v=sql.105).aspx
|
// http://msdn.microsoft.com/en-us/library/ms144250(v=sql.105).aspx
|
||||||
switch col.sortId {
|
switch col.SortId {
|
||||||
case 30, 31, 32, 33, 34:
|
case 30, 31, 32, 33, 34:
|
||||||
return cp437
|
return cp437
|
||||||
case 40, 41, 42, 44, 49, 55, 56, 57, 58, 59, 60, 61:
|
case 40, 41, 42, 44, 49, 55, 56, 57, 58, 59, 60, 61:
|
||||||
|
@ -86,7 +86,7 @@ func collation2charset(col collation) *charsetMap {
|
||||||
return cp1252
|
return cp1252
|
||||||
}
|
}
|
||||||
|
|
||||||
func charset2utf8(col collation, s []byte) string {
|
func CharsetToUTF8(col Collation, s []byte) string {
|
||||||
cm := collation2charset(col)
|
cm := collation2charset(col)
|
||||||
if cm == nil {
|
if cm == nil {
|
||||||
return string(s)
|
return string(s)
|
|
@ -0,0 +1,20 @@
|
||||||
|
package cp
|
||||||
|
|
||||||
|
// http://msdn.microsoft.com/en-us/library/dd340437.aspx
|
||||||
|
|
||||||
|
type Collation struct {
|
||||||
|
LcidAndFlags uint32
|
||||||
|
SortId uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Collation) getLcid() uint32 {
|
||||||
|
return c.LcidAndFlags & 0x000fffff
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Collation) getFlags() uint32 {
|
||||||
|
return (c.LcidAndFlags & 0x0ff00000) >> 20
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Collation) getVersion() uint32 {
|
||||||
|
return (c.LcidAndFlags & 0xf0000000) >> 28
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp1250 *charsetMap = &charsetMap{
|
var cp1250 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp1251 *charsetMap = &charsetMap{
|
var cp1251 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp1252 *charsetMap = &charsetMap{
|
var cp1252 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp1253 *charsetMap = &charsetMap{
|
var cp1253 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp1254 *charsetMap = &charsetMap{
|
var cp1254 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp1255 *charsetMap = &charsetMap{
|
var cp1255 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp1256 *charsetMap = &charsetMap{
|
var cp1256 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp1257 *charsetMap = &charsetMap{
|
var cp1257 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp1258 *charsetMap = &charsetMap{
|
var cp1258 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp437 *charsetMap = &charsetMap{
|
var cp437 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp850 *charsetMap = &charsetMap{
|
var cp850 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp874 *charsetMap = &charsetMap{
|
var cp874 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp932 *charsetMap = &charsetMap{
|
var cp932 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp936 *charsetMap = &charsetMap{
|
var cp936 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp949 *charsetMap = &charsetMap{
|
var cp949 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -1,4 +1,4 @@
|
||||||
package mssql
|
package cp
|
||||||
|
|
||||||
var cp950 *charsetMap = &charsetMap{
|
var cp950 *charsetMap = &charsetMap{
|
||||||
sb: [256]rune{
|
sb: [256]rune{
|
|
@ -4,19 +4,26 @@ import (
|
||||||
"log"
|
"log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Logger log.Logger
|
type Logger interface {
|
||||||
|
Printf(format string, v ...interface{})
|
||||||
|
Println(v ...interface{})
|
||||||
|
}
|
||||||
|
|
||||||
func (logger *Logger) Printf(format string, v ...interface{}) {
|
type optionalLogger struct {
|
||||||
if logger != nil {
|
logger Logger
|
||||||
(*log.Logger)(logger).Printf(format, v...)
|
}
|
||||||
|
|
||||||
|
func (o optionalLogger) Printf(format string, v ...interface{}) {
|
||||||
|
if o.logger != nil {
|
||||||
|
o.logger.Printf(format, v...)
|
||||||
} else {
|
} else {
|
||||||
log.Printf(format, v...)
|
log.Printf(format, v...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (logger *Logger) Println(v ...interface{}) {
|
func (o optionalLogger) Println(v ...interface{}) {
|
||||||
if logger != nil {
|
if o.logger != nil {
|
||||||
(*log.Logger)(logger).Println(v...)
|
o.logger.Println(v...)
|
||||||
} else {
|
} else {
|
||||||
log.Println(v...)
|
log.Println(v...)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,122 +1,269 @@
|
||||||
package mssql
|
package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var driverInstance = &Driver{processQueryText: true}
|
||||||
|
var driverInstanceNoProcess = &Driver{processQueryText: false}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
sql.Register("mssql", &MssqlDriver{})
|
sql.Register("mssql", driverInstance)
|
||||||
|
sql.Register("sqlserver", driverInstanceNoProcess)
|
||||||
|
createDialer = func(p *connectParams) dialer {
|
||||||
|
return tcpDialer{&net.Dialer{Timeout: p.dial_timeout, KeepAlive: p.keepAlive}}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type MssqlDriver struct {
|
// Abstract the dialer for testing and for non-TCP based connections.
|
||||||
log *log.Logger
|
type dialer interface {
|
||||||
|
Dial(ctx context.Context, addr string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *MssqlDriver) SetLogger(logger *log.Logger) {
|
var createDialer func(p *connectParams) dialer
|
||||||
d.log = logger
|
|
||||||
|
type tcpDialer struct {
|
||||||
|
nd *net.Dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckBadConn(err error) error {
|
func (d tcpDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
if err == io.EOF {
|
return d.nd.DialContext(ctx, "tcp", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Driver struct {
|
||||||
|
log optionalLogger
|
||||||
|
|
||||||
|
processQueryText bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenConnector opens a new connector. Useful to dial with a context.
|
||||||
|
func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
|
||||||
|
params, err := parseConnectParams(dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Connector{
|
||||||
|
params: params,
|
||||||
|
driver: d,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Open(dsn string) (driver.Conn, error) {
|
||||||
|
return d.open(context.Background(), dsn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connector holds the parsed DSN and is ready to make a new connection
|
||||||
|
// at any time.
|
||||||
|
//
|
||||||
|
// In the future, settings that cannot be passed through a string DSN
|
||||||
|
// may be set directly on the connector.
|
||||||
|
type Connector struct {
|
||||||
|
params connectParams
|
||||||
|
driver *Driver
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to the server and return a TDS connection.
|
||||||
|
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||||
|
return c.driver.connect(ctx, c.params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Driver underlying the Connector.
|
||||||
|
func (c *Connector) Driver() driver.Driver {
|
||||||
|
return c.driver
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetLogger(logger Logger) {
|
||||||
|
driverInstance.SetLogger(logger)
|
||||||
|
driverInstanceNoProcess.SetLogger(logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) SetLogger(logger Logger) {
|
||||||
|
d.log = optionalLogger{logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
sess *tdsSession
|
||||||
|
transactionCtx context.Context
|
||||||
|
|
||||||
|
processQueryText bool
|
||||||
|
connectionGood bool
|
||||||
|
|
||||||
|
outs map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) checkBadConn(err error) error {
|
||||||
|
// this is a hack to address Issue #275
|
||||||
|
// we set connectionGood flag to false if
|
||||||
|
// error indicates that connection is not usable
|
||||||
|
// but we return actual error instead of ErrBadConn
|
||||||
|
// this will cause connection to stay in a pool
|
||||||
|
// but next request to this connection will return ErrBadConn
|
||||||
|
|
||||||
|
// it might be possible to revise this hack after
|
||||||
|
// https://github.com/golang/go/issues/20807
|
||||||
|
// is implemented
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
case io.EOF:
|
||||||
return driver.ErrBadConn
|
return driver.ErrBadConn
|
||||||
|
case driver.ErrBadConn:
|
||||||
|
// It is an internal programming error if driver.ErrBadConn
|
||||||
|
// is ever passed to this function. driver.ErrBadConn should
|
||||||
|
// only ever be returned in response to a *mssql.Conn.connectionGood == false
|
||||||
|
// check in the external facing API.
|
||||||
|
panic("driver.ErrBadConn in checkBadConn. This should not happen.")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch e := err.(type) {
|
switch err.(type) {
|
||||||
case net.Error:
|
case net.Error:
|
||||||
if e.Timeout() {
|
c.connectionGood = false
|
||||||
return e
|
return err
|
||||||
}
|
case StreamError:
|
||||||
return driver.ErrBadConn
|
c.connectionGood = false
|
||||||
|
return err
|
||||||
default:
|
default:
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type MssqlConn struct {
|
func (c *Conn) clearOuts() {
|
||||||
sess *tdsSession
|
c.outs = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *MssqlConn) Commit() error {
|
func (c *Conn) simpleProcessResp(ctx context.Context) error {
|
||||||
|
tokchan := make(chan tokenStruct, 5)
|
||||||
|
go processResponse(ctx, c.sess, tokchan, c.outs)
|
||||||
|
c.clearOuts()
|
||||||
|
for tok := range tokchan {
|
||||||
|
switch token := tok.(type) {
|
||||||
|
case doneStruct:
|
||||||
|
if token.isError() {
|
||||||
|
return c.checkBadConn(token.getError())
|
||||||
|
}
|
||||||
|
case error:
|
||||||
|
return c.checkBadConn(token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Commit() error {
|
||||||
|
if !c.connectionGood {
|
||||||
|
return driver.ErrBadConn
|
||||||
|
}
|
||||||
|
if err := c.sendCommitRequest(); err != nil {
|
||||||
|
return c.checkBadConn(err)
|
||||||
|
}
|
||||||
|
return c.simpleProcessResp(c.transactionCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendCommitRequest() error {
|
||||||
headers := []headerStruct{
|
headers := []headerStruct{
|
||||||
{hdrtype: dataStmHdrTransDescr,
|
{hdrtype: dataStmHdrTransDescr,
|
||||||
data: transDescrHdr{c.sess.tranid, 1}.pack()},
|
data: transDescrHdr{c.sess.tranid, 1}.pack()},
|
||||||
}
|
}
|
||||||
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
|
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
|
||||||
return err
|
if c.sess.logFlags&logErrors != 0 {
|
||||||
}
|
c.sess.log.Printf("Failed to send CommitXact with %v", err)
|
||||||
|
|
||||||
tokchan := make(chan tokenStruct, 5)
|
|
||||||
go processResponse(c.sess, tokchan)
|
|
||||||
for tok := range tokchan {
|
|
||||||
switch token := tok.(type) {
|
|
||||||
case error:
|
|
||||||
return token
|
|
||||||
}
|
}
|
||||||
|
c.connectionGood = false
|
||||||
|
return fmt.Errorf("Faild to send CommitXact: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *MssqlConn) Rollback() error {
|
func (c *Conn) Rollback() error {
|
||||||
|
if !c.connectionGood {
|
||||||
|
return driver.ErrBadConn
|
||||||
|
}
|
||||||
|
if err := c.sendRollbackRequest(); err != nil {
|
||||||
|
return c.checkBadConn(err)
|
||||||
|
}
|
||||||
|
return c.simpleProcessResp(c.transactionCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendRollbackRequest() error {
|
||||||
headers := []headerStruct{
|
headers := []headerStruct{
|
||||||
{hdrtype: dataStmHdrTransDescr,
|
{hdrtype: dataStmHdrTransDescr,
|
||||||
data: transDescrHdr{c.sess.tranid, 1}.pack()},
|
data: transDescrHdr{c.sess.tranid, 1}.pack()},
|
||||||
}
|
}
|
||||||
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
|
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
|
||||||
return err
|
if c.sess.logFlags&logErrors != 0 {
|
||||||
}
|
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
|
||||||
|
|
||||||
tokchan := make(chan tokenStruct, 5)
|
|
||||||
go processResponse(c.sess, tokchan)
|
|
||||||
for tok := range tokchan {
|
|
||||||
switch token := tok.(type) {
|
|
||||||
case error:
|
|
||||||
return token
|
|
||||||
}
|
}
|
||||||
|
c.connectionGood = false
|
||||||
|
return fmt.Errorf("Failed to send RollbackXact: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *MssqlConn) Begin() (driver.Tx, error) {
|
func (c *Conn) Begin() (driver.Tx, error) {
|
||||||
|
return c.begin(context.Background(), isolationUseCurrent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) {
|
||||||
|
if !c.connectionGood {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
err = c.sendBeginRequest(ctx, tdsIsolation)
|
||||||
|
if err != nil {
|
||||||
|
return nil, c.checkBadConn(err)
|
||||||
|
}
|
||||||
|
tx, err = c.processBeginResponse(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, c.checkBadConn(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
|
||||||
|
c.transactionCtx = ctx
|
||||||
headers := []headerStruct{
|
headers := []headerStruct{
|
||||||
{hdrtype: dataStmHdrTransDescr,
|
{hdrtype: dataStmHdrTransDescr,
|
||||||
data: transDescrHdr{0, 1}.pack()},
|
data: transDescrHdr{0, 1}.pack()},
|
||||||
}
|
}
|
||||||
if err := sendBeginXact(c.sess.buf, headers, 0, ""); err != nil {
|
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
|
||||||
return nil, CheckBadConn(err)
|
if c.sess.logFlags&logErrors != 0 {
|
||||||
|
c.sess.log.Printf("Failed to send BeginXact with %v", err)
|
||||||
}
|
}
|
||||||
tokchan := make(chan tokenStruct, 5)
|
c.connectionGood = false
|
||||||
go processResponse(c.sess, tokchan)
|
return fmt.Errorf("Failed to send BiginXant: %v", err)
|
||||||
for tok := range tokchan {
|
|
||||||
switch token := tok.(type) {
|
|
||||||
case error:
|
|
||||||
if c.sess.tranid != 0 {
|
|
||||||
return nil, token
|
|
||||||
}
|
|
||||||
return nil, CheckBadConn(token)
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
|
||||||
|
if err := c.simpleProcessResp(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
// successful BEGINXACT request will return sess.tranid
|
// successful BEGINXACT request will return sess.tranid
|
||||||
// for started transaction
|
// for started transaction
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
|
func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
|
||||||
params, err := parseConnectParams(dsn)
|
params, err := parseConnectParams(dsn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return d.connect(ctx, params)
|
||||||
|
}
|
||||||
|
|
||||||
sess, err := connect(params)
|
// connect to the server, using the provided context for dialing only.
|
||||||
|
func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, error) {
|
||||||
|
sess, err := connect(ctx, d.log, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// main server failed, try fail-over partner
|
// main server failed, try fail-over partner
|
||||||
if params.failOverPartner == "" {
|
if params.failOverPartner == "" {
|
||||||
|
@ -128,24 +275,29 @@ func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
params.port = params.failOverPort
|
params.port = params.failOverPort
|
||||||
}
|
}
|
||||||
|
|
||||||
sess, err = connect(params)
|
sess, err = connect(ctx, d.log, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// fail-over partner also failed, now fail
|
// fail-over partner also failed, now fail
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := &MssqlConn{sess}
|
conn := &Conn{
|
||||||
conn.sess.log = (*Logger)(d.log)
|
sess: sess,
|
||||||
|
transactionCtx: context.Background(),
|
||||||
|
processQueryText: d.processQueryText,
|
||||||
|
connectionGood: true,
|
||||||
|
}
|
||||||
|
conn.sess.log = d.log
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *MssqlConn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
return c.sess.buf.transport.Close()
|
return c.sess.buf.transport.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
type MssqlStmt struct {
|
type Stmt struct {
|
||||||
c *MssqlConn
|
c *Conn
|
||||||
query string
|
query string
|
||||||
paramCount int
|
paramCount int
|
||||||
notifSub *queryNotifSub
|
notifSub *queryNotifSub
|
||||||
|
@ -157,16 +309,30 @@ type queryNotifSub struct {
|
||||||
timeout uint32
|
timeout uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
|
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
|
||||||
q, paramCount := parseParams(query)
|
if !c.connectionGood {
|
||||||
return &MssqlStmt{c, q, paramCount, nil}, nil
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
|
||||||
|
return c.prepareCopyIn(query)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.prepareContext(context.Background(), query)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MssqlStmt) Close() error {
|
func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
|
||||||
|
paramCount := -1
|
||||||
|
if c.processQueryText {
|
||||||
|
query, paramCount = parseParams(query)
|
||||||
|
}
|
||||||
|
return &Stmt{c, query, paramCount, nil}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {
|
func (s *Stmt) SetQueryNotification(id, options string, timeout time.Duration) {
|
||||||
to := uint32(timeout / time.Second)
|
to := uint32(timeout / time.Second)
|
||||||
if to < 1 {
|
if to < 1 {
|
||||||
to = 1
|
to = 1
|
||||||
|
@ -174,183 +340,273 @@ func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Durati
|
||||||
s.notifSub = &queryNotifSub{id, options, to}
|
s.notifSub = &queryNotifSub{id, options, to}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MssqlStmt) NumInput() int {
|
func (s *Stmt) NumInput() int {
|
||||||
return s.paramCount
|
return s.paramCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MssqlStmt) sendQuery(args []driver.Value) (err error) {
|
func (s *Stmt) sendQuery(args []namedValue) (err error) {
|
||||||
headers := []headerStruct{
|
headers := []headerStruct{
|
||||||
{hdrtype: dataStmHdrTransDescr,
|
{hdrtype: dataStmHdrTransDescr,
|
||||||
data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
|
data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.notifSub != nil {
|
if s.notifSub != nil {
|
||||||
headers = append(headers, headerStruct{hdrtype: dataStmHdrQueryNotif,
|
headers = append(headers,
|
||||||
data: queryNotifHdr{s.notifSub.msgText, s.notifSub.options, s.notifSub.timeout}.pack()})
|
headerStruct{
|
||||||
|
hdrtype: dataStmHdrQueryNotif,
|
||||||
|
data: queryNotifHdr{
|
||||||
|
s.notifSub.msgText,
|
||||||
|
s.notifSub.options,
|
||||||
|
s.notifSub.timeout,
|
||||||
|
}.pack(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(args) != s.paramCount {
|
// no need to check number of parameters here, it is checked by database/sql
|
||||||
return errors.New(fmt.Sprintf("sql: expected %d parameters, got %d", s.paramCount, len(args)))
|
|
||||||
}
|
|
||||||
if s.c.sess.logFlags&logSQL != 0 {
|
if s.c.sess.logFlags&logSQL != 0 {
|
||||||
s.c.sess.log.Println(s.query)
|
s.c.sess.log.Println(s.query)
|
||||||
}
|
}
|
||||||
if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
|
if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
|
||||||
for i := 0; i < len(args); i++ {
|
for i := 0; i < len(args); i++ {
|
||||||
s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i])
|
if len(args[i].Name) > 0 {
|
||||||
|
s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value)
|
||||||
|
} else {
|
||||||
|
s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
|
if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
|
||||||
if s.c.sess.tranid != 0 {
|
if s.c.sess.logFlags&logErrors != 0 {
|
||||||
return err
|
s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
|
||||||
}
|
}
|
||||||
return CheckBadConn(err)
|
s.c.connectionGood = false
|
||||||
|
return fmt.Errorf("failed to send SQL Batch: %v", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
params := make([]Param, len(args)+2)
|
proc := Sp_ExecuteSql
|
||||||
decls := make([]string, len(args))
|
var params []Param
|
||||||
params[0], err = s.makeParam(s.query)
|
if isProc(s.query) {
|
||||||
|
proc.name = s.query
|
||||||
|
params, _, err = s.makeRPCParams(args, 0)
|
||||||
|
} else {
|
||||||
|
var decls []string
|
||||||
|
params, decls, err = s.makeRPCParams(args, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for i, val := range args {
|
params[0] = makeStrParam(s.query)
|
||||||
params[i+2], err = s.makeParam(val)
|
params[1] = makeStrParam(strings.Join(decls, ","))
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
name := fmt.Sprintf("@p%d", i+1)
|
if err = sendRpc(s.c.sess.buf, headers, proc, 0, params); err != nil {
|
||||||
params[i+2].Name = name
|
if s.c.sess.logFlags&logErrors != 0 {
|
||||||
decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti))
|
s.c.sess.log.Printf("Failed to send Rpc with %v", err)
|
||||||
}
|
}
|
||||||
params[1], err = s.makeParam(strings.Join(decls, ","))
|
s.c.connectionGood = false
|
||||||
if err != nil {
|
return fmt.Errorf("Failed to send RPC: %v", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil {
|
|
||||||
if s.c.sess.tranid != 0 {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return CheckBadConn(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MssqlStmt) Query(args []driver.Value) (res driver.Rows, err error) {
|
// isProc takes the query text in s and determines if it is a stored proc name
|
||||||
if err = s.sendQuery(args); err != nil {
|
// or SQL text.
|
||||||
return
|
func isProc(s string) bool {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
if s[0] == '[' && s[len(s)-1] == ']' && strings.ContainsAny(s, "\n\r") == false {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return !strings.ContainsAny(s, " \t\n\r;")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) {
|
||||||
|
var err error
|
||||||
|
params := make([]Param, len(args)+offset)
|
||||||
|
decls := make([]string, len(args))
|
||||||
|
for i, val := range args {
|
||||||
|
params[i+offset], err = s.makeParam(val.Value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
var name string
|
||||||
|
if len(val.Name) > 0 {
|
||||||
|
name = "@" + val.Name
|
||||||
|
} else {
|
||||||
|
name = fmt.Sprintf("@p%d", val.Ordinal)
|
||||||
|
}
|
||||||
|
params[i+offset].Name = name
|
||||||
|
decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+offset].ti))
|
||||||
|
}
|
||||||
|
return params, decls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type namedValue struct {
|
||||||
|
Name string
|
||||||
|
Ordinal int
|
||||||
|
Value driver.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertOldArgs(args []driver.Value) []namedValue {
|
||||||
|
list := make([]namedValue, len(args))
|
||||||
|
for i, v := range args {
|
||||||
|
list[i] = namedValue{
|
||||||
|
Ordinal: i + 1,
|
||||||
|
Value: v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||||
|
return s.queryContext(context.Background(), convertOldArgs(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) {
|
||||||
|
if !s.c.connectionGood {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
if err = s.sendQuery(args); err != nil {
|
||||||
|
return nil, s.c.checkBadConn(err)
|
||||||
|
}
|
||||||
|
return s.processQueryResponse(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
|
||||||
tokchan := make(chan tokenStruct, 5)
|
tokchan := make(chan tokenStruct, 5)
|
||||||
go processResponse(s.c.sess, tokchan)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
|
||||||
|
s.c.clearOuts()
|
||||||
// process metadata
|
// process metadata
|
||||||
var cols []string
|
var cols []columnStruct
|
||||||
loop:
|
loop:
|
||||||
for tok := range tokchan {
|
for tok := range tokchan {
|
||||||
switch token := tok.(type) {
|
switch token := tok.(type) {
|
||||||
// by ignoring DONE token we effectively
|
// By ignoring DONE token we effectively
|
||||||
// skip empty result-sets
|
// skip empty result-sets.
|
||||||
// this improves results in queryes like that:
|
// This improves results in queries like that:
|
||||||
// set nocount on; select 1
|
// set nocount on; select 1
|
||||||
// see TestIgnoreEmptyResults test
|
// see TestIgnoreEmptyResults test
|
||||||
//case doneStruct:
|
//case doneStruct:
|
||||||
//break loop
|
//break loop
|
||||||
case []columnStruct:
|
case []columnStruct:
|
||||||
cols = make([]string, len(token))
|
cols = token
|
||||||
for i, col := range token {
|
|
||||||
cols[i] = col.ColName
|
|
||||||
}
|
|
||||||
break loop
|
break loop
|
||||||
|
case doneStruct:
|
||||||
|
if token.isError() {
|
||||||
|
return nil, s.c.checkBadConn(token.getError())
|
||||||
|
}
|
||||||
case error:
|
case error:
|
||||||
if s.c.sess.tranid != 0 {
|
return nil, s.c.checkBadConn(token)
|
||||||
return nil, token
|
|
||||||
}
|
|
||||||
return nil, CheckBadConn(token)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols}, nil
|
res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MssqlStmt) Exec(args []driver.Value) (res driver.Result, err error) {
|
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||||
if err = s.sendQuery(args); err != nil {
|
return s.exec(context.Background(), convertOldArgs(args))
|
||||||
return
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) {
|
||||||
|
if !s.c.connectionGood {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
if err = s.sendQuery(args); err != nil {
|
||||||
|
return nil, s.c.checkBadConn(err)
|
||||||
|
}
|
||||||
|
if res, err = s.processExec(ctx); err != nil {
|
||||||
|
return nil, s.c.checkBadConn(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
|
||||||
tokchan := make(chan tokenStruct, 5)
|
tokchan := make(chan tokenStruct, 5)
|
||||||
go processResponse(s.c.sess, tokchan)
|
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
|
||||||
|
s.c.clearOuts()
|
||||||
var rowCount int64
|
var rowCount int64
|
||||||
for token := range tokchan {
|
for token := range tokchan {
|
||||||
switch token := token.(type) {
|
switch token := token.(type) {
|
||||||
case doneInProcStruct:
|
case doneInProcStruct:
|
||||||
if token.Status&doneCount != 0 {
|
if token.Status&doneCount != 0 {
|
||||||
rowCount = int64(token.RowCount)
|
rowCount += int64(token.RowCount)
|
||||||
}
|
}
|
||||||
case doneStruct:
|
case doneStruct:
|
||||||
if token.Status&doneCount != 0 {
|
if token.Status&doneCount != 0 {
|
||||||
rowCount = int64(token.RowCount)
|
rowCount += int64(token.RowCount)
|
||||||
|
}
|
||||||
|
if token.isError() {
|
||||||
|
return nil, token.getError()
|
||||||
}
|
}
|
||||||
case error:
|
case error:
|
||||||
if s.c.sess.logFlags&logErrors != 0 {
|
|
||||||
s.c.sess.log.Println("got error:", token)
|
|
||||||
}
|
|
||||||
if s.c.sess.tranid != 0 {
|
|
||||||
return nil, token
|
return nil, token
|
||||||
}
|
}
|
||||||
return nil, CheckBadConn(token)
|
|
||||||
}
|
}
|
||||||
}
|
return &Result{s.c, rowCount}, nil
|
||||||
return &MssqlResult{s.c, rowCount}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type MssqlRows struct {
|
type Rows struct {
|
||||||
sess *tdsSession
|
stmt *Stmt
|
||||||
cols []string
|
cols []columnStruct
|
||||||
tokchan chan tokenStruct
|
tokchan chan tokenStruct
|
||||||
|
|
||||||
nextCols []string
|
nextCols []columnStruct
|
||||||
|
|
||||||
|
cancel func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *MssqlRows) Close() error {
|
func (rc *Rows) Close() error {
|
||||||
|
rc.cancel()
|
||||||
for _ = range rc.tokchan {
|
for _ = range rc.tokchan {
|
||||||
}
|
}
|
||||||
rc.tokchan = nil
|
rc.tokchan = nil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *MssqlRows) Columns() (res []string) {
|
func (rc *Rows) Columns() (res []string) {
|
||||||
return rc.cols
|
res = make([]string, len(rc.cols))
|
||||||
|
for i, col := range rc.cols {
|
||||||
|
res[i] = col.ColName
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *MssqlRows) Next(dest []driver.Value) (err error) {
|
func (rc *Rows) Next(dest []driver.Value) error {
|
||||||
|
if !rc.stmt.c.connectionGood {
|
||||||
|
return driver.ErrBadConn
|
||||||
|
}
|
||||||
if rc.nextCols != nil {
|
if rc.nextCols != nil {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
for tok := range rc.tokchan {
|
for tok := range rc.tokchan {
|
||||||
switch tokdata := tok.(type) {
|
switch tokdata := tok.(type) {
|
||||||
case []columnStruct:
|
case []columnStruct:
|
||||||
cols := make([]string, len(tokdata))
|
rc.nextCols = tokdata
|
||||||
for i, col := range tokdata {
|
|
||||||
cols[i] = col.ColName
|
|
||||||
}
|
|
||||||
rc.nextCols = cols
|
|
||||||
return io.EOF
|
return io.EOF
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
for i := range dest {
|
for i := range dest {
|
||||||
dest[i] = tokdata[i]
|
dest[i] = tokdata[i]
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
case doneStruct:
|
||||||
|
if tokdata.isError() {
|
||||||
|
return rc.stmt.c.checkBadConn(tokdata.getError())
|
||||||
|
}
|
||||||
case error:
|
case error:
|
||||||
return tokdata
|
return rc.stmt.c.checkBadConn(tokdata)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *MssqlRows) HasNextResultSet() bool {
|
func (rc *Rows) HasNextResultSet() bool {
|
||||||
return rc.nextCols != nil
|
return rc.nextCols != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *MssqlRows) NextResultSet() error {
|
func (rc *Rows) NextResultSet() error {
|
||||||
rc.cols = rc.nextCols
|
rc.cols = rc.nextCols
|
||||||
rc.nextCols = nil
|
rc.nextCols = nil
|
||||||
if rc.cols == nil {
|
if rc.cols == nil {
|
||||||
|
@ -359,11 +615,69 @@ func (rc *MssqlRows) NextResultSet() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
|
// It should return
|
||||||
if val == nil {
|
// the value type that can be used to scan types into. For example, the database
|
||||||
|
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
|
||||||
|
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
|
||||||
|
return makeGoLangScanType(r.cols[index].ti)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
|
||||||
|
// database system type name without the length. Type names should be uppercase.
|
||||||
|
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
|
||||||
|
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
|
||||||
|
// "TIMESTAMP".
|
||||||
|
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
|
||||||
|
return makeGoLangTypeName(r.cols[index].ti)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowsColumnTypeLength may be implemented by Rows. It should return the length
|
||||||
|
// of the column type if the column is a variable length type. If the column is
|
||||||
|
// not a variable length type ok should return false.
|
||||||
|
// If length is not limited other than system limits, it should return math.MaxInt64.
|
||||||
|
// The following are examples of returned values for various types:
|
||||||
|
// TEXT (math.MaxInt64, true)
|
||||||
|
// varchar(10) (10, true)
|
||||||
|
// nvarchar(10) (10, true)
|
||||||
|
// decimal (0, false)
|
||||||
|
// int (0, false)
|
||||||
|
// bytea(30) (30, true)
|
||||||
|
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
|
||||||
|
return makeGoLangTypeLength(r.cols[index].ti)
|
||||||
|
}
|
||||||
|
|
||||||
|
// It should return
|
||||||
|
// the precision and scale for decimal types. If not applicable, ok should be false.
|
||||||
|
// The following are examples of returned values for various types:
|
||||||
|
// decimal(38, 4) (38, 4, true)
|
||||||
|
// int (0, 0, false)
|
||||||
|
// decimal (math.MaxInt64, math.MaxInt64, true)
|
||||||
|
func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
|
||||||
|
return makeGoLangTypePrecisionScale(r.cols[index].ti)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The nullable value should
|
||||||
|
// be true if it is known the column may be null, or false if the column is known
|
||||||
|
// to be not nullable.
|
||||||
|
// If the column nullability is unknown, ok should be false.
|
||||||
|
func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
|
||||||
|
nullable = r.cols[index].Flags&colFlagNullable != 0
|
||||||
|
ok = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeStrParam(val string) (res Param) {
|
||||||
res.ti.TypeId = typeNVarChar
|
res.ti.TypeId = typeNVarChar
|
||||||
|
res.buffer = str2ucs2(val)
|
||||||
|
res.ti.Size = len(res.buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) makeParam(val driver.Value) (res Param, err error) {
|
||||||
|
if val == nil {
|
||||||
|
res.ti.TypeId = typeNull
|
||||||
res.buffer = nil
|
res.buffer = nil
|
||||||
res.ti.Size = 2
|
res.ti.Size = 0
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch val := val.(type) {
|
switch val := val.(type) {
|
||||||
|
@ -382,9 +696,7 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
|
||||||
res.ti.Size = len(val)
|
res.ti.Size = len(val)
|
||||||
res.buffer = val
|
res.buffer = val
|
||||||
case string:
|
case string:
|
||||||
res.ti.TypeId = typeNVarChar
|
res = makeStrParam(val)
|
||||||
res.buffer = str2ucs2(val)
|
|
||||||
res.ti.Size = len(res.buffer)
|
|
||||||
case bool:
|
case bool:
|
||||||
res.ti.TypeId = typeBitN
|
res.ti.TypeId = typeBitN
|
||||||
res.ti.Size = 1
|
res.ti.Size = 1
|
||||||
|
@ -425,22 +737,21 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
|
||||||
binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
|
binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("mssql: unknown type for %T", val)
|
return s.makeParamExtra(val)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type MssqlResult struct {
|
type Result struct {
|
||||||
c *MssqlConn
|
c *Conn
|
||||||
rowsAffected int64
|
rowsAffected int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *MssqlResult) RowsAffected() (int64, error) {
|
func (r *Result) RowsAffected() (int64, error) {
|
||||||
return r.rowsAffected, nil
|
return r.rowsAffected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *MssqlResult) LastInsertId() (int64, error) {
|
func (r *Result) LastInsertId() (int64, error) {
|
||||||
s, err := r.c.Prepare("select cast(@@identity as bigint)")
|
s, err := r.c.Prepare("select cast(@@identity as bigint)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
|
|
@ -1,11 +0,0 @@
|
||||||
// +build go1.3
|
|
||||||
|
|
||||||
package mssql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func createDialer(p connectParams) *net.Dialer {
|
|
||||||
return &net.Dialer{Timeout: p.dial_timeout, KeepAlive: p.keepAlive}
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
// +build !go1.3
|
|
||||||
|
|
||||||
package mssql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func createDialer(p *connectParams) *net.Dialer {
|
|
||||||
return &net.Dialer{Timeout: p.dial_timeout}
|
|
||||||
}
|
|
|
@ -0,0 +1,91 @@
|
||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package mssql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ driver.Pinger = &Conn{}
|
||||||
|
|
||||||
|
// Ping is used to check if the remote server is available and satisfies the Pinger interface.
|
||||||
|
func (c *Conn) Ping(ctx context.Context) error {
|
||||||
|
if !c.connectionGood {
|
||||||
|
return driver.ErrBadConn
|
||||||
|
}
|
||||||
|
stmt := &Stmt{c, `select 1;`, 0, nil}
|
||||||
|
_, err := stmt.ExecContext(ctx, nil)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ driver.ConnBeginTx = &Conn{}
|
||||||
|
|
||||||
|
// BeginTx satisfies ConnBeginTx.
|
||||||
|
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||||
|
if !c.connectionGood {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
if opts.ReadOnly {
|
||||||
|
return nil, errors.New("Read-only transactions are not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tdsIsolation isoLevel
|
||||||
|
switch sql.IsolationLevel(opts.Isolation) {
|
||||||
|
case sql.LevelDefault:
|
||||||
|
tdsIsolation = isolationUseCurrent
|
||||||
|
case sql.LevelReadUncommitted:
|
||||||
|
tdsIsolation = isolationReadUncommited
|
||||||
|
case sql.LevelReadCommitted:
|
||||||
|
tdsIsolation = isolationReadCommited
|
||||||
|
case sql.LevelWriteCommitted:
|
||||||
|
return nil, errors.New("LevelWriteCommitted isolation level is not supported")
|
||||||
|
case sql.LevelRepeatableRead:
|
||||||
|
tdsIsolation = isolationRepeatableRead
|
||||||
|
case sql.LevelSnapshot:
|
||||||
|
tdsIsolation = isolationSnapshot
|
||||||
|
case sql.LevelSerializable:
|
||||||
|
tdsIsolation = isolationSerializable
|
||||||
|
case sql.LevelLinearizable:
|
||||||
|
return nil, errors.New("LevelLinearizable isolation level is not supported")
|
||||||
|
default:
|
||||||
|
return nil, errors.New("Isolation level is not supported or unknown")
|
||||||
|
}
|
||||||
|
return c.begin(ctx, tdsIsolation)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||||
|
if !c.connectionGood {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
|
||||||
|
return c.prepareCopyIn(query)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.prepareContext(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||||
|
if !s.c.connectionGood {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
list := make([]namedValue, len(args))
|
||||||
|
for i, nv := range args {
|
||||||
|
list[i] = namedValue(nv)
|
||||||
|
}
|
||||||
|
return s.queryContext(ctx, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||||
|
if !s.c.connectionGood {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
list := make([]namedValue, len(args))
|
||||||
|
for i, nv := range args {
|
||||||
|
list[i] = namedValue(nv)
|
||||||
|
}
|
||||||
|
return s.exec(ctx, list)
|
||||||
|
}
|
|
@ -0,0 +1,64 @@
|
||||||
|
// +build go1.9
|
||||||
|
|
||||||
|
package mssql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
// "github.com/cockroachdb/apd"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Type alias provided for compibility.
|
||||||
|
//
|
||||||
|
// Deprecated: users should transition to the new names when possible.
|
||||||
|
type MssqlDriver = Driver
|
||||||
|
type MssqlBulk = Bulk
|
||||||
|
type MssqlBulkOptions = BulkOptions
|
||||||
|
type MssqlConn = Conn
|
||||||
|
type MssqlResult = Result
|
||||||
|
type MssqlRows = Rows
|
||||||
|
type MssqlStmt = Stmt
|
||||||
|
|
||||||
|
var _ driver.NamedValueChecker = &Conn{}
|
||||||
|
|
||||||
|
func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
|
||||||
|
switch v := nv.Value.(type) {
|
||||||
|
case sql.Out:
|
||||||
|
if c.outs == nil {
|
||||||
|
c.outs = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
c.outs[nv.Name] = v.Dest
|
||||||
|
|
||||||
|
// Unwrap the Out value and check the inner value.
|
||||||
|
lnv := *nv
|
||||||
|
lnv.Value = v.Dest
|
||||||
|
err := c.CheckNamedValue(&lnv)
|
||||||
|
if err != nil {
|
||||||
|
if err != driver.ErrSkip {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
lnv.Value, err = driver.DefaultParameterConverter.ConvertValue(lnv.Value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nv.Value = sql.Out{Dest: lnv.Value}
|
||||||
|
return nil
|
||||||
|
// case *apd.Decimal:
|
||||||
|
// return nil
|
||||||
|
default:
|
||||||
|
return driver.ErrSkip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) makeParamExtra(val driver.Value) (res Param, err error) {
|
||||||
|
switch val := val.(type) {
|
||||||
|
case sql.Out:
|
||||||
|
res, err = s.makeParam(val.Dest)
|
||||||
|
res.Flags = fByRevValue
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("mssql: unknown type for %T", val)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
// +build !go1.9
|
||||||
|
|
||||||
|
package mssql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Stmt) makeParamExtra(val driver.Value) (Param, error) {
|
||||||
|
return Param{}, fmt.Errorf("mssql: unknown type for %T", val)
|
||||||
|
}
|
|
@ -33,7 +33,7 @@ func (c *timeoutConn) Read(b []byte) (n int, err error) {
|
||||||
c.continueRead = false
|
c.continueRead = false
|
||||||
}
|
}
|
||||||
if !c.continueRead {
|
if !c.continueRead {
|
||||||
var packet uint8
|
var packet packetType
|
||||||
packet, err = c.buf.BeginRead()
|
packet, err = c.buf.BeginRead()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("Cannot read handshake packet: %s", err.Error())
|
err = fmt.Errorf("Cannot read handshake packet: %s", err.Error())
|
||||||
|
|
|
@ -59,7 +59,7 @@ type NTLMAuth struct {
|
||||||
Workstation string
|
Workstation string
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAuth(user, password, service, workstation string) (Auth, bool) {
|
func getAuth(user, password, service, workstation string) (auth, bool) {
|
||||||
if !strings.ContainsRune(user, '\\') {
|
if !strings.ContainsRune(user, '\\') {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,9 @@ type parser struct {
|
||||||
w bytes.Buffer
|
w bytes.Buffer
|
||||||
paramCount int
|
paramCount int
|
||||||
paramMax int
|
paramMax int
|
||||||
|
|
||||||
|
// using map as a set
|
||||||
|
namedParams map[string]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *parser) next() (rune, bool) {
|
func (p *parser) next() (rune, bool) {
|
||||||
|
@ -40,12 +43,13 @@ type stateFunc func(*parser) stateFunc
|
||||||
func parseParams(query string) (string, int) {
|
func parseParams(query string) (string, int) {
|
||||||
p := &parser{
|
p := &parser{
|
||||||
r: bytes.NewReader([]byte(query)),
|
r: bytes.NewReader([]byte(query)),
|
||||||
|
namedParams: map[string]bool{},
|
||||||
}
|
}
|
||||||
state := parseNormal
|
state := parseNormal
|
||||||
for state != nil {
|
for state != nil {
|
||||||
state = state(p)
|
state = state(p)
|
||||||
}
|
}
|
||||||
return p.w.String(), p.paramMax
|
return p.w.String(), p.paramMax + len(p.namedParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseNormal(p *parser) stateFunc {
|
func parseNormal(p *parser) stateFunc {
|
||||||
|
@ -55,7 +59,7 @@ func parseNormal(p *parser) stateFunc {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if ch == '?' {
|
if ch == '?' {
|
||||||
return parseParameter
|
return parseOrdinalParameter
|
||||||
} else if ch == '$' || ch == ':' {
|
} else if ch == '$' || ch == ':' {
|
||||||
ch2, ok := p.next()
|
ch2, ok := p.next()
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -64,7 +68,9 @@ func parseNormal(p *parser) stateFunc {
|
||||||
}
|
}
|
||||||
p.unread()
|
p.unread()
|
||||||
if ch2 >= '0' && ch2 <= '9' {
|
if ch2 >= '0' && ch2 <= '9' {
|
||||||
return parseParameter
|
return parseOrdinalParameter
|
||||||
|
} else if 'a' <= ch2 && ch2 <= 'z' || 'A' <= ch2 && ch2 <= 'Z' {
|
||||||
|
return parseNamedParameter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
p.write(ch)
|
p.write(ch)
|
||||||
|
@ -83,7 +89,7 @@ func parseNormal(p *parser) stateFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseParameter(p *parser) stateFunc {
|
func parseOrdinalParameter(p *parser) stateFunc {
|
||||||
var paramN int
|
var paramN int
|
||||||
var ok bool
|
var ok bool
|
||||||
for {
|
for {
|
||||||
|
@ -113,6 +119,30 @@ func parseParameter(p *parser) stateFunc {
|
||||||
return parseNormal
|
return parseNormal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseNamedParameter(p *parser) stateFunc {
|
||||||
|
var paramName string
|
||||||
|
var ok bool
|
||||||
|
for {
|
||||||
|
var ch rune
|
||||||
|
ch, ok = p.next()
|
||||||
|
if ok && (ch >= '0' && ch <= '9' || 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z') {
|
||||||
|
paramName = paramName + string(ch)
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
p.unread()
|
||||||
|
}
|
||||||
|
p.namedParams[paramName] = true
|
||||||
|
p.w.WriteString("@")
|
||||||
|
p.w.WriteString(paramName)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return parseNormal
|
||||||
|
}
|
||||||
|
|
||||||
func parseQuote(p *parser) stateFunc {
|
func parseQuote(p *parser) stateFunc {
|
||||||
for {
|
for {
|
||||||
ch, ok := p.next()
|
ch, ok := p.next()
|
||||||
|
|
|
@ -113,7 +113,7 @@ type SSPIAuth struct {
|
||||||
ctxt SecHandle
|
ctxt SecHandle
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAuth(user, password, service, workstation string) (Auth, bool) {
|
func getAuth(user, password, service, workstation string) (auth, bool) {
|
||||||
if user == "" {
|
if user == "" {
|
||||||
return &SSPIAuth{Service: service}, true
|
return &SSPIAuth{Service: service}, true
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package mssql
|
package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
@ -9,11 +10,13 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode"
|
||||||
"unicode/utf16"
|
"unicode/utf16"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
@ -47,8 +50,11 @@ func parseInstances(msg []byte) map[string]map[string]string {
|
||||||
return results
|
return results
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInstances(address string) (map[string]map[string]string, error) {
|
func getInstances(ctx context.Context, address string) (map[string]map[string]string, error) {
|
||||||
conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)
|
dialer := &net.Dialer{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
conn, err := dialer.DialContext(ctx, "udp", address+":1434")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -79,11 +85,16 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// packet types
|
// packet types
|
||||||
|
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
|
||||||
const (
|
const (
|
||||||
packSQLBatch = 1
|
packSQLBatch packetType = 1
|
||||||
packRPCRequest = 3
|
packRPCRequest = 3
|
||||||
packReply = 4
|
packReply = 4
|
||||||
packCancel = 6
|
|
||||||
|
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
|
||||||
|
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
|
||||||
|
packAttention = 6
|
||||||
|
|
||||||
packBulkLoadBCP = 7
|
packBulkLoadBCP = 7
|
||||||
packTransMgrReq = 14
|
packTransMgrReq = 14
|
||||||
packNormal = 15
|
packNormal = 15
|
||||||
|
@ -119,7 +130,7 @@ type tdsSession struct {
|
||||||
columns []columnStruct
|
columns []columnStruct
|
||||||
tranid uint64
|
tranid uint64
|
||||||
logFlags uint64
|
logFlags uint64
|
||||||
log *Logger
|
log optionalLogger
|
||||||
routedServer string
|
routedServer string
|
||||||
routedPort uint16
|
routedPort uint16
|
||||||
}
|
}
|
||||||
|
@ -131,6 +142,7 @@ const (
|
||||||
logSQL = 8
|
logSQL = 8
|
||||||
logParams = 16
|
logParams = 16
|
||||||
logTransaction = 32
|
logTransaction = 32
|
||||||
|
logDebug = 64
|
||||||
)
|
)
|
||||||
|
|
||||||
type columnStruct struct {
|
type columnStruct struct {
|
||||||
|
@ -490,6 +502,11 @@ func readBVarChar(r io.Reader) (res string, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A zero length could be returned, return an empty string
|
||||||
|
if numchars == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
return readUcs2(r, int(numchars))
|
return readUcs2(r, int(numchars))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -588,7 +605,7 @@ func (hdr transDescrHdr) pack() (res []byte) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
|
func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
|
||||||
// calculatint total length
|
// Calculating total length.
|
||||||
var totallen uint32 = 4
|
var totallen uint32 = 4
|
||||||
for _, hdr := range headers {
|
for _, hdr := range headers {
|
||||||
totallen += 4 + 2 + uint32(len(hdr.data))
|
totallen += 4 + 2 + uint32(len(hdr.data))
|
||||||
|
@ -616,9 +633,7 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendSqlBatch72(buf *tdsBuffer,
|
func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err error) {
|
||||||
sqltext string,
|
|
||||||
headers []headerStruct) (err error) {
|
|
||||||
buf.BeginPacket(packSQLBatch)
|
buf.BeginPacket(packSQLBatch)
|
||||||
|
|
||||||
if err = writeAllHeaders(buf, headers); err != nil {
|
if err = writeAllHeaders(buf, headers); err != nil {
|
||||||
|
@ -632,6 +647,13 @@ func sendSqlBatch72(buf *tdsBuffer,
|
||||||
return buf.FinishPacket()
|
return buf.FinishPacket()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
|
||||||
|
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
|
||||||
|
func sendAttention(buf *tdsBuffer) error {
|
||||||
|
buf.BeginPacket(packAttention)
|
||||||
|
return buf.FinishPacket()
|
||||||
|
}
|
||||||
|
|
||||||
type connectParams struct {
|
type connectParams struct {
|
||||||
logFlags uint64
|
logFlags uint64
|
||||||
port uint64
|
port uint64
|
||||||
|
@ -654,6 +676,7 @@ type connectParams struct {
|
||||||
typeFlags uint8
|
typeFlags uint8
|
||||||
failOverPartner string
|
failOverPartner string
|
||||||
failOverPort uint64
|
failOverPort uint64
|
||||||
|
packetSize uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func splitConnectionString(dsn string) (res map[string]string) {
|
func splitConnectionString(dsn string) (res map[string]string) {
|
||||||
|
@ -677,9 +700,241 @@ func splitConnectionString(dsn string) (res map[string]string) {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Splits a URL in the ODBC format
|
||||||
|
func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
|
||||||
|
res := map[string]string{}
|
||||||
|
|
||||||
|
type parserState int
|
||||||
|
const (
|
||||||
|
// Before the start of a key
|
||||||
|
parserStateBeforeKey parserState = iota
|
||||||
|
|
||||||
|
// Inside a key
|
||||||
|
parserStateKey
|
||||||
|
|
||||||
|
// Beginning of a value. May be bare or braced
|
||||||
|
parserStateBeginValue
|
||||||
|
|
||||||
|
// Inside a bare value
|
||||||
|
parserStateBareValue
|
||||||
|
|
||||||
|
// Inside a braced value
|
||||||
|
parserStateBracedValue
|
||||||
|
|
||||||
|
// A closing brace inside a braced value.
|
||||||
|
// May be the end of the value or an escaped closing brace, depending on the next character
|
||||||
|
parserStateBracedValueClosingBrace
|
||||||
|
|
||||||
|
// After a value. Next character should be a semicolon or whitespace.
|
||||||
|
parserStateEndValue
|
||||||
|
)
|
||||||
|
|
||||||
|
var state = parserStateBeforeKey
|
||||||
|
|
||||||
|
var key string
|
||||||
|
var value string
|
||||||
|
|
||||||
|
for i, c := range dsn {
|
||||||
|
switch state {
|
||||||
|
case parserStateBeforeKey:
|
||||||
|
switch {
|
||||||
|
case c == '=':
|
||||||
|
return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
|
||||||
|
case !unicode.IsSpace(c) && c != ';':
|
||||||
|
state = parserStateKey
|
||||||
|
key += string(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
case parserStateKey:
|
||||||
|
switch c {
|
||||||
|
case '=':
|
||||||
|
key = normalizeOdbcKey(key)
|
||||||
|
if len(key) == 0 {
|
||||||
|
return res, fmt.Errorf("Unexpected end of key at index %d.", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
state = parserStateBeginValue
|
||||||
|
|
||||||
|
case ';':
|
||||||
|
// Key without value
|
||||||
|
key = normalizeOdbcKey(key)
|
||||||
|
if len(key) == 0 {
|
||||||
|
return res, fmt.Errorf("Unexpected end of key at index %d.", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
res[key] = value
|
||||||
|
key = ""
|
||||||
|
value = ""
|
||||||
|
state = parserStateBeforeKey
|
||||||
|
|
||||||
|
default:
|
||||||
|
key += string(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
case parserStateBeginValue:
|
||||||
|
switch {
|
||||||
|
case c == '{':
|
||||||
|
state = parserStateBracedValue
|
||||||
|
case c == ';':
|
||||||
|
// Empty value
|
||||||
|
res[key] = value
|
||||||
|
key = ""
|
||||||
|
state = parserStateBeforeKey
|
||||||
|
case unicode.IsSpace(c):
|
||||||
|
// Ignore whitespace
|
||||||
|
default:
|
||||||
|
state = parserStateBareValue
|
||||||
|
value += string(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
case parserStateBareValue:
|
||||||
|
if c == ';' {
|
||||||
|
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
|
||||||
|
key = ""
|
||||||
|
value = ""
|
||||||
|
state = parserStateBeforeKey
|
||||||
|
} else {
|
||||||
|
value += string(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
case parserStateBracedValue:
|
||||||
|
if c == '}' {
|
||||||
|
state = parserStateBracedValueClosingBrace
|
||||||
|
} else {
|
||||||
|
value += string(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
case parserStateBracedValueClosingBrace:
|
||||||
|
if c == '}' {
|
||||||
|
// Escaped closing brace
|
||||||
|
value += string(c)
|
||||||
|
state = parserStateBracedValue
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// End of braced value
|
||||||
|
res[key] = value
|
||||||
|
key = ""
|
||||||
|
value = ""
|
||||||
|
|
||||||
|
// This character is the first character past the end,
|
||||||
|
// so it needs to be parsed like the parserStateEndValue state.
|
||||||
|
state = parserStateEndValue
|
||||||
|
switch {
|
||||||
|
case c == ';':
|
||||||
|
state = parserStateBeforeKey
|
||||||
|
case unicode.IsSpace(c):
|
||||||
|
// Ignore whitespace
|
||||||
|
default:
|
||||||
|
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
case parserStateEndValue:
|
||||||
|
switch {
|
||||||
|
case c == ';':
|
||||||
|
state = parserStateBeforeKey
|
||||||
|
case unicode.IsSpace(c):
|
||||||
|
// Ignore whitespace
|
||||||
|
default:
|
||||||
|
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case parserStateBeforeKey: // Okay
|
||||||
|
case parserStateKey: // Unfinished key. Treat as key without value.
|
||||||
|
key = normalizeOdbcKey(key)
|
||||||
|
if len(key) == 0 {
|
||||||
|
return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
|
||||||
|
}
|
||||||
|
res[key] = value
|
||||||
|
case parserStateBeginValue: // Empty value
|
||||||
|
res[key] = value
|
||||||
|
case parserStateBareValue:
|
||||||
|
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
|
||||||
|
case parserStateBracedValue:
|
||||||
|
return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
|
||||||
|
case parserStateBracedValueClosingBrace: // End of braced value
|
||||||
|
res[key] = value
|
||||||
|
case parserStateEndValue: // Okay
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalizes the given string as an ODBC-format key
|
||||||
|
func normalizeOdbcKey(s string) string {
|
||||||
|
return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Splits a URL of the form sqlserver://username:password@host/instance?param1=value¶m2=value
|
||||||
|
func splitConnectionStringURL(dsn string) (map[string]string, error) {
|
||||||
|
res := map[string]string{}
|
||||||
|
|
||||||
|
u, err := url.Parse(dsn)
|
||||||
|
if err != nil {
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Scheme != "sqlserver" {
|
||||||
|
return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.User != nil {
|
||||||
|
res["user id"] = u.User.Username()
|
||||||
|
p, exists := u.User.Password()
|
||||||
|
if exists {
|
||||||
|
res["password"] = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
host, port, err := net.SplitHostPort(u.Host)
|
||||||
|
if err != nil {
|
||||||
|
host = u.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(u.Path) > 0 {
|
||||||
|
res["server"] = host + "\\" + u.Path[1:]
|
||||||
|
} else {
|
||||||
|
res["server"] = host
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(port) > 0 {
|
||||||
|
res["port"] = port
|
||||||
|
}
|
||||||
|
|
||||||
|
query := u.Query()
|
||||||
|
for k, v := range query {
|
||||||
|
if len(v) > 1 {
|
||||||
|
return res, fmt.Errorf("key %s provided more than once", k)
|
||||||
|
}
|
||||||
|
res[strings.ToLower(k)] = v[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
func parseConnectParams(dsn string) (connectParams, error) {
|
func parseConnectParams(dsn string) (connectParams, error) {
|
||||||
params := splitConnectionString(dsn)
|
|
||||||
var p connectParams
|
var p connectParams
|
||||||
|
|
||||||
|
var params map[string]string
|
||||||
|
if strings.HasPrefix(dsn, "odbc:") {
|
||||||
|
parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
|
||||||
|
if err != nil {
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
params = parameters
|
||||||
|
} else if strings.HasPrefix(dsn, "sqlserver://") {
|
||||||
|
parameters, err := splitConnectionStringURL(dsn)
|
||||||
|
if err != nil {
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
params = parameters
|
||||||
|
} else {
|
||||||
|
params = splitConnectionString(dsn)
|
||||||
|
}
|
||||||
|
|
||||||
strlog, ok := params["log"]
|
strlog, ok := params["log"]
|
||||||
if ok {
|
if ok {
|
||||||
var err error
|
var err error
|
||||||
|
@ -712,7 +967,32 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
p.dial_timeout = 5 * time.Second
|
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
|
||||||
|
// Default packet size remains at 4096 bytes
|
||||||
|
p.packetSize = 4096
|
||||||
|
strpsize, ok := params["packet size"]
|
||||||
|
if ok {
|
||||||
|
var err error
|
||||||
|
psize, err := strconv.ParseUint(strpsize, 0, 16)
|
||||||
|
if err != nil {
|
||||||
|
f := "Invalid packet size '%v': %v"
|
||||||
|
return p, fmt.Errorf(f, strpsize, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
|
||||||
|
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
|
||||||
|
// a higher packet size, the server will respond with an ENVCHANGE request to
|
||||||
|
// alter the packet size to 16383 bytes.
|
||||||
|
p.packetSize = uint16(psize)
|
||||||
|
if p.packetSize < 512 {
|
||||||
|
p.packetSize = 512
|
||||||
|
} else if p.packetSize > 32767 {
|
||||||
|
p.packetSize = 32767
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
|
||||||
|
p.dial_timeout = 15 * time.Second
|
||||||
p.conn_timeout = 30 * time.Second
|
p.conn_timeout = 30 * time.Second
|
||||||
strconntimeout, ok := params["connection timeout"]
|
strconntimeout, ok := params["connection timeout"]
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -732,8 +1012,12 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
||||||
}
|
}
|
||||||
p.dial_timeout = time.Duration(timeout) * time.Second
|
p.dial_timeout = time.Duration(timeout) * time.Second
|
||||||
}
|
}
|
||||||
keepAlive, ok := params["keepalive"]
|
|
||||||
if ok {
|
// default keep alive should be 30 seconds according to spec:
|
||||||
|
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
|
||||||
|
p.keepAlive = 30 * time.Second
|
||||||
|
|
||||||
|
if keepAlive, ok := params["keepalive"]; ok {
|
||||||
timeout, err := strconv.ParseUint(keepAlive, 0, 16)
|
timeout, err := strconv.ParseUint(keepAlive, 0, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f := "Invalid keepAlive value '%s': %s"
|
f := "Invalid keepAlive value '%s': %s"
|
||||||
|
@ -743,7 +1027,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
||||||
}
|
}
|
||||||
encrypt, ok := params["encrypt"]
|
encrypt, ok := params["encrypt"]
|
||||||
if ok {
|
if ok {
|
||||||
if strings.ToUpper(encrypt) == "DISABLE" {
|
if strings.EqualFold(encrypt, "DISABLE") {
|
||||||
p.disableEncryption = true
|
p.disableEncryption = true
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
|
@ -819,7 +1103,7 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Auth interface {
|
type auth interface {
|
||||||
InitialBytes() ([]byte, error)
|
InitialBytes() ([]byte, error)
|
||||||
NextBytes([]byte) ([]byte, error)
|
NextBytes([]byte) ([]byte, error)
|
||||||
Free()
|
Free()
|
||||||
|
@ -828,7 +1112,7 @@ type Auth interface {
|
||||||
// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
|
// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
|
||||||
// list of IP addresses. So if there is more than one, try them all and
|
// list of IP addresses. So if there is more than one, try them all and
|
||||||
// use the first one that allows a connection.
|
// use the first one that allows a connection.
|
||||||
func dialConnection(p connectParams) (conn net.Conn, err error) {
|
func dialConnection(ctx context.Context, p connectParams) (conn net.Conn, err error) {
|
||||||
var ips []net.IP
|
var ips []net.IP
|
||||||
ips, err = net.LookupIP(p.host)
|
ips, err = net.LookupIP(p.host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -839,9 +1123,9 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
|
||||||
ips = []net.IP{ip}
|
ips = []net.IP{ip}
|
||||||
}
|
}
|
||||||
if len(ips) == 1 {
|
if len(ips) == 1 {
|
||||||
d := createDialer(p)
|
d := createDialer(&p)
|
||||||
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
|
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
|
||||||
conn, err = d.Dial("tcp", addr)
|
conn, err = d.Dial(ctx, addr)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
//Try Dials in parallel to avoid waiting for timeouts.
|
//Try Dials in parallel to avoid waiting for timeouts.
|
||||||
|
@ -850,9 +1134,9 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
|
||||||
portStr := strconv.Itoa(int(p.port))
|
portStr := strconv.Itoa(int(p.port))
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
go func(ip net.IP) {
|
go func(ip net.IP) {
|
||||||
d := createDialer(p)
|
d := createDialer(&p)
|
||||||
addr := net.JoinHostPort(ip.String(), portStr)
|
addr := net.JoinHostPort(ip.String(), portStr)
|
||||||
conn, err := d.Dial("tcp", addr)
|
conn, err := d.Dial(ctx, addr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
connChan <- conn
|
connChan <- conn
|
||||||
} else {
|
} else {
|
||||||
|
@ -887,16 +1171,15 @@ func dialConnection(p connectParams) (conn net.Conn, err error) {
|
||||||
f := "Unable to open tcp connection with host '%v:%v': %v"
|
f := "Unable to open tcp connection with host '%v:%v': %v"
|
||||||
return nil, fmt.Errorf(f, p.host, p.port, err.Error())
|
return nil, fmt.Errorf(f, p.host, p.port, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, err
|
return conn, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(p connectParams) (res *tdsSession, err error) {
|
func connect(ctx context.Context, log optionalLogger, p connectParams) (res *tdsSession, err error) {
|
||||||
res = nil
|
res = nil
|
||||||
// if instance is specified use instance resolution service
|
// if instance is specified use instance resolution service
|
||||||
if p.instance != "" {
|
if p.instance != "" {
|
||||||
p.instance = strings.ToUpper(p.instance)
|
p.instance = strings.ToUpper(p.instance)
|
||||||
instances, err := getInstances(p.host)
|
instances, err := getInstances(ctx, p.host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f := "Unable to get instances from Sql Server Browser on host %v: %v"
|
f := "Unable to get instances from Sql Server Browser on host %v: %v"
|
||||||
return nil, fmt.Errorf(f, p.host, err.Error())
|
return nil, fmt.Errorf(f, p.host, err.Error())
|
||||||
|
@ -914,16 +1197,17 @@ func connect(p connectParams) (res *tdsSession, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
initiate_connection:
|
initiate_connection:
|
||||||
conn, err := dialConnection(p)
|
conn, err := dialConnection(ctx, p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
toconn := NewTimeoutConn(conn, p.conn_timeout)
|
toconn := NewTimeoutConn(conn, p.conn_timeout)
|
||||||
|
|
||||||
outbuf := newTdsBuffer(4096, toconn)
|
outbuf := newTdsBuffer(p.packetSize, toconn)
|
||||||
sess := tdsSession{
|
sess := tdsSession{
|
||||||
buf: outbuf,
|
buf: outbuf,
|
||||||
|
log: log,
|
||||||
logFlags: p.logFlags,
|
logFlags: p.logFlags,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -969,8 +1253,7 @@ initiate_connection:
|
||||||
if p.certificate != "" {
|
if p.certificate != "" {
|
||||||
pem, err := ioutil.ReadFile(p.certificate)
|
pem, err := ioutil.ReadFile(p.certificate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f := "Cannot read certificate '%s': %s"
|
return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
|
||||||
return nil, fmt.Errorf(f, p.certificate, err.Error())
|
|
||||||
}
|
}
|
||||||
certs := x509.NewCertPool()
|
certs := x509.NewCertPool()
|
||||||
certs.AppendCertsFromPEM(pem)
|
certs.AppendCertsFromPEM(pem)
|
||||||
|
@ -980,15 +1263,20 @@ initiate_connection:
|
||||||
config.InsecureSkipVerify = true
|
config.InsecureSkipVerify = true
|
||||||
}
|
}
|
||||||
config.ServerName = p.hostInCertificate
|
config.ServerName = p.hostInCertificate
|
||||||
|
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
|
||||||
|
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
|
||||||
|
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
|
||||||
|
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
|
||||||
|
config.DynamicRecordSizingDisabled = true
|
||||||
outbuf.transport = conn
|
outbuf.transport = conn
|
||||||
toconn.buf = outbuf
|
toconn.buf = outbuf
|
||||||
tlsConn := tls.Client(toconn, &config)
|
tlsConn := tls.Client(toconn, &config)
|
||||||
err = tlsConn.Handshake()
|
err = tlsConn.Handshake()
|
||||||
|
|
||||||
toconn.buf = nil
|
toconn.buf = nil
|
||||||
outbuf.transport = tlsConn
|
outbuf.transport = tlsConn
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f := "TLS Handshake failed: %s"
|
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
|
||||||
return nil, fmt.Errorf(f, err.Error())
|
|
||||||
}
|
}
|
||||||
if encrypt == encryptOff {
|
if encrypt == encryptOff {
|
||||||
outbuf.afterFirst = func() {
|
outbuf.afterFirst = func() {
|
||||||
|
@ -999,7 +1287,7 @@ initiate_connection:
|
||||||
|
|
||||||
login := login{
|
login := login{
|
||||||
TDSVersion: verTDS74,
|
TDSVersion: verTDS74,
|
||||||
PacketSize: uint32(len(outbuf.buf)),
|
PacketSize: uint32(outbuf.PackageSize()),
|
||||||
Database: p.database,
|
Database: p.database,
|
||||||
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
|
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
|
||||||
HostName: p.workstation,
|
HostName: p.workstation,
|
||||||
|
@ -1028,7 +1316,7 @@ initiate_connection:
|
||||||
var sspi_msg []byte
|
var sspi_msg []byte
|
||||||
continue_login:
|
continue_login:
|
||||||
tokchan := make(chan tokenStruct, 5)
|
tokchan := make(chan tokenStruct, 5)
|
||||||
go processResponse(&sess, tokchan)
|
go processResponse(context.Background(), &sess, tokchan, nil)
|
||||||
success := false
|
success := false
|
||||||
for tok := range tokchan {
|
for tok := range tokchan {
|
||||||
switch token := tok.(type) {
|
switch token := tok.(type) {
|
||||||
|
@ -1042,6 +1330,10 @@ continue_login:
|
||||||
sess.loginAck = token
|
sess.loginAck = token
|
||||||
case error:
|
case error:
|
||||||
return nil, fmt.Errorf("Login error: %s", token.Error())
|
return nil, fmt.Errorf("Login error: %s", token.Error())
|
||||||
|
case doneStruct:
|
||||||
|
if token.isError() {
|
||||||
|
return nil, fmt.Errorf("Login error: %s", token.getError())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if sspi_msg != nil {
|
if sspi_msg != nil {
|
||||||
|
|
|
@ -1,30 +1,40 @@
|
||||||
package mssql
|
package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//go:generate stringer -type token
|
||||||
|
|
||||||
|
type token byte
|
||||||
|
|
||||||
// token ids
|
// token ids
|
||||||
const (
|
const (
|
||||||
tokenReturnStatus = 121 // 0x79
|
tokenReturnStatus token = 121 // 0x79
|
||||||
tokenColMetadata = 129 // 0x81
|
tokenColMetadata token = 129 // 0x81
|
||||||
tokenOrder = 169 // 0xA9
|
tokenOrder token = 169 // 0xA9
|
||||||
tokenError = 170 // 0xAA
|
tokenError token = 170 // 0xAA
|
||||||
tokenInfo = 171 // 0xAB
|
tokenInfo token = 171 // 0xAB
|
||||||
tokenLoginAck = 173 // 0xad
|
tokenReturnValue token = 0xAC
|
||||||
tokenRow = 209 // 0xd1
|
tokenLoginAck token = 173 // 0xad
|
||||||
tokenNbcRow = 210 // 0xd2
|
tokenRow token = 209 // 0xd1
|
||||||
tokenEnvChange = 227 // 0xE3
|
tokenNbcRow token = 210 // 0xd2
|
||||||
tokenSSPI = 237 // 0xED
|
tokenEnvChange token = 227 // 0xE3
|
||||||
tokenDone = 253 // 0xFD
|
tokenSSPI token = 237 // 0xED
|
||||||
tokenDoneProc = 254
|
tokenDone token = 253 // 0xFD
|
||||||
tokenDoneInProc = 255
|
tokenDoneProc token = 254
|
||||||
|
tokenDoneInProc token = 255
|
||||||
)
|
)
|
||||||
|
|
||||||
// done flags
|
// done flags
|
||||||
|
// https://msdn.microsoft.com/en-us/library/dd340421.aspx
|
||||||
const (
|
const (
|
||||||
doneFinal = 0
|
doneFinal = 0
|
||||||
doneMore = 1
|
doneMore = 1
|
||||||
|
@ -59,6 +69,13 @@ const (
|
||||||
envRouting = 20
|
envRouting = 20
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// COLMETADATA flags
|
||||||
|
// https://msdn.microsoft.com/en-us/library/dd357363.aspx
|
||||||
|
const (
|
||||||
|
colFlagNullable = 1
|
||||||
|
// TODO implement more flags
|
||||||
|
)
|
||||||
|
|
||||||
// interface for all tokens
|
// interface for all tokens
|
||||||
type tokenStruct interface{}
|
type tokenStruct interface{}
|
||||||
|
|
||||||
|
@ -70,6 +87,19 @@ type doneStruct struct {
|
||||||
Status uint16
|
Status uint16
|
||||||
CurCmd uint16
|
CurCmd uint16
|
||||||
RowCount uint64
|
RowCount uint64
|
||||||
|
errors []Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d doneStruct) isError() bool {
|
||||||
|
return d.Status&doneError != 0 || len(d.errors) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d doneStruct) getError() Error {
|
||||||
|
if len(d.errors) > 0 {
|
||||||
|
return d.errors[len(d.errors)-1]
|
||||||
|
} else {
|
||||||
|
return Error{Message: "Request failed but didn't provide reason"}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type doneInProcStruct doneStruct
|
type doneInProcStruct doneStruct
|
||||||
|
@ -120,27 +150,23 @@ func processEnvChg(sess *tdsSession) {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
case envTypLanguage:
|
case envTypLanguage:
|
||||||
//currently ignored
|
// currently ignored
|
||||||
// old value
|
// new value
|
||||||
_, err = readBVarChar(r)
|
if _, err = readBVarChar(r); err != nil {
|
||||||
if err != nil {
|
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
// new value
|
// old value
|
||||||
_, err = readBVarChar(r)
|
if _, err = readBVarChar(r); err != nil {
|
||||||
if err != nil {
|
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
case envTypCharset:
|
case envTypCharset:
|
||||||
//currently ignored
|
// currently ignored
|
||||||
// old value
|
// new value
|
||||||
_, err = readBVarChar(r)
|
if _, err = readBVarChar(r); err != nil {
|
||||||
if err != nil {
|
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
// new value
|
// old value
|
||||||
_, err = readBVarChar(r)
|
if _, err = readBVarChar(r); err != nil {
|
||||||
if err != nil {
|
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
case envTypPacketSize:
|
case envTypPacketSize:
|
||||||
|
@ -156,38 +182,55 @@ func processEnvChg(sess *tdsSession) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error())
|
badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error())
|
||||||
}
|
}
|
||||||
if len(sess.buf.buf) != packetsizei {
|
sess.buf.ResizeBuffer(packetsizei)
|
||||||
newbuf := make([]byte, packetsizei)
|
|
||||||
copy(newbuf, sess.buf.buf)
|
|
||||||
sess.buf.buf = newbuf
|
|
||||||
}
|
|
||||||
case envSortId:
|
case envSortId:
|
||||||
// currently ignored
|
// currently ignored
|
||||||
// old value, should be 0
|
// new value
|
||||||
if _, err = readBVarChar(r); err != nil {
|
if _, err = readBVarChar(r); err != nil {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
// new value
|
// old value, should be 0
|
||||||
if _, err = readBVarChar(r); err != nil {
|
if _, err = readBVarChar(r); err != nil {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
case envSortFlags:
|
case envSortFlags:
|
||||||
// currently ignored
|
// currently ignored
|
||||||
// old value, should be 0
|
// new value
|
||||||
if _, err = readBVarChar(r); err != nil {
|
if _, err = readBVarChar(r); err != nil {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
// new value
|
// old value, should be 0
|
||||||
if _, err = readBVarChar(r); err != nil {
|
if _, err = readBVarChar(r); err != nil {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
case envSqlCollation:
|
case envSqlCollation:
|
||||||
// currently ignored
|
// currently ignored
|
||||||
// old value
|
var collationSize uint8
|
||||||
if _, err = readBVarChar(r); err != nil {
|
err = binary.Read(r, binary.LittleEndian, &collationSize)
|
||||||
|
if err != nil {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
// new value
|
|
||||||
|
// SQL Collation data should contain 5 bytes in length
|
||||||
|
if collationSize != 5 {
|
||||||
|
badStreamPanicf("Invalid SQL Collation size value returned from server: %s", collationSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4 bytes, contains: LCID ColFlags Version
|
||||||
|
var info uint32
|
||||||
|
err = binary.Read(r, binary.LittleEndian, &info)
|
||||||
|
if err != nil {
|
||||||
|
badStreamPanic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1 byte, contains: sortID
|
||||||
|
var sortID uint8
|
||||||
|
err = binary.Read(r, binary.LittleEndian, &sortID)
|
||||||
|
if err != nil {
|
||||||
|
badStreamPanic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// old value, should be 0
|
||||||
if _, err = readBVarChar(r); err != nil {
|
if _, err = readBVarChar(r); err != nil {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
|
@ -226,21 +269,21 @@ func processEnvChg(sess *tdsSession) {
|
||||||
sess.tranid = 0
|
sess.tranid = 0
|
||||||
case envEnlistDTC:
|
case envEnlistDTC:
|
||||||
// currently ignored
|
// currently ignored
|
||||||
// old value
|
// new value, should be 0
|
||||||
if _, err = readBVarChar(r); err != nil {
|
if _, err = readBVarChar(r); err != nil {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
// new value, should be 0
|
// old value
|
||||||
if _, err = readBVarChar(r); err != nil {
|
if _, err = readBVarChar(r); err != nil {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
case envDefectTran:
|
case envDefectTran:
|
||||||
// currently ignored
|
// currently ignored
|
||||||
// old value, should be 0
|
// new value
|
||||||
if _, err = readBVarChar(r); err != nil {
|
if _, err = readBVarChar(r); err != nil {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
// new value
|
// old value, should be 0
|
||||||
if _, err = readBVarChar(r); err != nil {
|
if _, err = readBVarChar(r); err != nil {
|
||||||
badStreamPanic(err)
|
badStreamPanic(err)
|
||||||
}
|
}
|
||||||
|
@ -358,6 +401,7 @@ func parseOrder(r *tdsBuffer) (res orderStruct) {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://msdn.microsoft.com/en-us/library/dd340421.aspx
|
||||||
func parseDone(r *tdsBuffer) (res doneStruct) {
|
func parseDone(r *tdsBuffer) (res doneStruct) {
|
||||||
res.Status = r.uint16()
|
res.Status = r.uint16()
|
||||||
res.CurCmd = r.uint16()
|
res.CurCmd = r.uint16()
|
||||||
|
@ -365,6 +409,7 @@ func parseDone(r *tdsBuffer) (res doneStruct) {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://msdn.microsoft.com/en-us/library/dd340553.aspx
|
||||||
func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) {
|
func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) {
|
||||||
res.Status = r.uint16()
|
res.Status = r.uint16()
|
||||||
res.CurCmd = r.uint16()
|
res.CurCmd = r.uint16()
|
||||||
|
@ -473,26 +518,57 @@ func parseInfo(r *tdsBuffer) (res Error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func processResponse(sess *tdsSession, ch chan tokenStruct) {
|
// https://msdn.microsoft.com/en-us/library/dd303881.aspx
|
||||||
|
func parseReturnValue(r *tdsBuffer) (nv namedValue) {
|
||||||
|
/*
|
||||||
|
ParamOrdinal
|
||||||
|
ParamName
|
||||||
|
Status
|
||||||
|
UserType
|
||||||
|
Flags
|
||||||
|
TypeInfo
|
||||||
|
CryptoMetadata
|
||||||
|
Value
|
||||||
|
*/
|
||||||
|
r.uint16()
|
||||||
|
nv.Name = r.BVarChar()
|
||||||
|
r.byte()
|
||||||
|
r.uint32() // UserType (uint16 prior to 7.2)
|
||||||
|
r.uint16()
|
||||||
|
ti := readTypeInfo(r)
|
||||||
|
nv.Value = ti.Reader(&ti, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
|
if sess.logFlags&logErrors != 0 {
|
||||||
|
sess.log.Printf("ERROR: Intercepted panic %v", err)
|
||||||
|
}
|
||||||
ch <- err
|
ch <- err
|
||||||
}
|
}
|
||||||
close(ch)
|
close(ch)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
packet_type, err := sess.buf.BeginRead()
|
packet_type, err := sess.buf.BeginRead()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if sess.logFlags&logErrors != 0 {
|
||||||
|
sess.log.Printf("ERROR: BeginRead failed %v", err)
|
||||||
|
}
|
||||||
ch <- err
|
ch <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if packet_type != packReply {
|
if packet_type != packReply {
|
||||||
badStreamPanicf("invalid response packet type, expected REPLY, actual: %d", packet_type)
|
badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply))
|
||||||
}
|
}
|
||||||
var columns []columnStruct
|
var columns []columnStruct
|
||||||
var lastError Error
|
errs := make([]Error, 0, 5)
|
||||||
var failed bool
|
|
||||||
for {
|
for {
|
||||||
token := sess.buf.byte()
|
token := token(sess.buf.byte())
|
||||||
|
if sess.logFlags&logDebug != 0 {
|
||||||
|
sess.log.Printf("got token %v", token)
|
||||||
|
}
|
||||||
switch token {
|
switch token {
|
||||||
case tokenSSPI:
|
case tokenSSPI:
|
||||||
ch <- parseSSPIMsg(sess.buf)
|
ch <- parseSSPIMsg(sess.buf)
|
||||||
|
@ -514,18 +590,17 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) {
|
||||||
ch <- done
|
ch <- done
|
||||||
case tokenDone, tokenDoneProc:
|
case tokenDone, tokenDoneProc:
|
||||||
done := parseDone(sess.buf)
|
done := parseDone(sess.buf)
|
||||||
if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
|
done.errors = errs
|
||||||
sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
|
if sess.logFlags&logDebug != 0 {
|
||||||
}
|
sess.log.Printf("got DONE or DONEPROC status=%d", done.Status)
|
||||||
if done.Status&doneError != 0 || failed {
|
|
||||||
ch <- lastError
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if done.Status&doneSrvError != 0 {
|
if done.Status&doneSrvError != 0 {
|
||||||
lastError.Message = "Server Error"
|
ch <- errors.New("SQL Server had internal error")
|
||||||
ch <- lastError
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
|
||||||
|
sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
|
||||||
|
}
|
||||||
ch <- done
|
ch <- done
|
||||||
if done.Status&doneMore == 0 {
|
if done.Status&doneMore == 0 {
|
||||||
return
|
return
|
||||||
|
@ -544,18 +619,210 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) {
|
||||||
case tokenEnvChange:
|
case tokenEnvChange:
|
||||||
processEnvChg(sess)
|
processEnvChg(sess)
|
||||||
case tokenError:
|
case tokenError:
|
||||||
lastError = parseError72(sess.buf)
|
err := parseError72(sess.buf)
|
||||||
failed = true
|
if sess.logFlags&logDebug != 0 {
|
||||||
|
sess.log.Printf("got ERROR %d %s", err.Number, err.Message)
|
||||||
|
}
|
||||||
|
errs = append(errs, err)
|
||||||
if sess.logFlags&logErrors != 0 {
|
if sess.logFlags&logErrors != 0 {
|
||||||
sess.log.Println(lastError.Message)
|
sess.log.Println(err.Message)
|
||||||
}
|
}
|
||||||
case tokenInfo:
|
case tokenInfo:
|
||||||
info := parseInfo(sess.buf)
|
info := parseInfo(sess.buf)
|
||||||
|
if sess.logFlags&logDebug != 0 {
|
||||||
|
sess.log.Printf("got INFO %d %s", info.Number, info.Message)
|
||||||
|
}
|
||||||
if sess.logFlags&logMessages != 0 {
|
if sess.logFlags&logMessages != 0 {
|
||||||
sess.log.Println(info.Message)
|
sess.log.Println(info.Message)
|
||||||
}
|
}
|
||||||
|
case tokenReturnValue:
|
||||||
|
nv := parseReturnValue(sess.buf)
|
||||||
|
if len(nv.Name) > 0 {
|
||||||
|
name := nv.Name[1:] // Remove the leading "@".
|
||||||
|
if ov, has := outs[name]; has {
|
||||||
|
err = scanIntoOut(nv.Value, ov)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("scan error", err)
|
||||||
|
ch <- err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
badStreamPanicf("Unknown token type: %d", token)
|
badStreamPanic(fmt.Errorf("unknown token type returned: %v", token))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanIntoOut(fromServer, scanInto interface{}) error {
|
||||||
|
switch fs := fromServer.(type) {
|
||||||
|
case int64:
|
||||||
|
switch si := scanInto.(type) {
|
||||||
|
case *int64:
|
||||||
|
*si = fs
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported scan into type %[1]T for server type %[2]T", scanInto, fromServer)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
switch si := scanInto.(type) {
|
||||||
|
case *string:
|
||||||
|
*si = fs
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported scan into type %[1]T for server type %[2]T", scanInto, fromServer)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unsupported type from server %[1]T=%[1]v", fromServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
type parseRespIter byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
parseRespIterContinue parseRespIter = iota // Continue parsing current token.
|
||||||
|
parseRespIterNext // Fetch the next token.
|
||||||
|
parseRespIterDone // Done with parsing the response.
|
||||||
|
)
|
||||||
|
|
||||||
|
type parseRespState byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
parseRespStateNormal parseRespState = iota // Normal response state.
|
||||||
|
parseRespStateCancel // Query is canceled, wait for server to confirm.
|
||||||
|
parseRespStateClosing // Waiting for tokens to come through.
|
||||||
|
)
|
||||||
|
|
||||||
|
type parseResp struct {
|
||||||
|
sess *tdsSession
|
||||||
|
ctxDone <-chan struct{}
|
||||||
|
state parseRespState
|
||||||
|
cancelError error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter {
|
||||||
|
if err := sendAttention(ts.sess.buf); err != nil {
|
||||||
|
ts.dlogf("failed to send attention signal %v", err)
|
||||||
|
ch <- err
|
||||||
|
return parseRespIterDone
|
||||||
|
}
|
||||||
|
ts.state = parseRespStateCancel
|
||||||
|
return parseRespIterContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *parseResp) dlog(msg string) {
|
||||||
|
if ts.sess.logFlags&logDebug != 0 {
|
||||||
|
ts.sess.log.Println(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (ts *parseResp) dlogf(f string, v ...interface{}) {
|
||||||
|
if ts.sess.logFlags&logDebug != 0 {
|
||||||
|
ts.sess.log.Printf(f, v...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter {
|
||||||
|
switch ts.state {
|
||||||
|
default:
|
||||||
|
panic("unknown state")
|
||||||
|
case parseRespStateNormal:
|
||||||
|
select {
|
||||||
|
case tok, ok := <-tokChan:
|
||||||
|
if !ok {
|
||||||
|
ts.dlog("response finished")
|
||||||
|
return parseRespIterDone
|
||||||
|
}
|
||||||
|
if err, ok := tok.(net.Error); ok && err.Timeout() {
|
||||||
|
ts.cancelError = err
|
||||||
|
ts.dlog("got timeout error, sending attention signal to server")
|
||||||
|
return ts.sendAttention(ch)
|
||||||
|
}
|
||||||
|
// Pass the token along.
|
||||||
|
ch <- tok
|
||||||
|
return parseRespIterContinue
|
||||||
|
|
||||||
|
case <-ts.ctxDone:
|
||||||
|
ts.ctxDone = nil
|
||||||
|
ts.dlog("got cancel message, sending attention signal to server")
|
||||||
|
return ts.sendAttention(ch)
|
||||||
|
}
|
||||||
|
case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth
|
||||||
|
select {
|
||||||
|
case tok, ok := <-tokChan:
|
||||||
|
if !ok {
|
||||||
|
ts.dlog("response finished but waiting for attention ack")
|
||||||
|
return parseRespIterNext
|
||||||
|
}
|
||||||
|
switch tok := tok.(type) {
|
||||||
|
default:
|
||||||
|
// Ignore all other tokens while waiting.
|
||||||
|
// The TDS spec says other tokens may arrive after an attention
|
||||||
|
// signal is sent. Ignore these tokens and continue looking for
|
||||||
|
// a DONE with attention confirm mark.
|
||||||
|
case doneStruct:
|
||||||
|
if tok.Status&doneAttn != 0 {
|
||||||
|
ts.dlog("got cancellation confirmation from server")
|
||||||
|
if ts.cancelError != nil {
|
||||||
|
ch <- ts.cancelError
|
||||||
|
ts.cancelError = nil
|
||||||
|
} else {
|
||||||
|
ch <- ctx.Err()
|
||||||
|
}
|
||||||
|
return parseRespIterDone
|
||||||
|
}
|
||||||
|
|
||||||
|
// If an error happens during cancel, pass it along and just stop.
|
||||||
|
// We are uncertain to receive more tokens.
|
||||||
|
case error:
|
||||||
|
ch <- tok
|
||||||
|
ts.state = parseRespStateClosing
|
||||||
|
}
|
||||||
|
return parseRespIterContinue
|
||||||
|
case <-ts.ctxDone:
|
||||||
|
ts.ctxDone = nil
|
||||||
|
ts.state = parseRespStateClosing
|
||||||
|
return parseRespIterContinue
|
||||||
|
}
|
||||||
|
case parseRespStateClosing: // Wait for current token chan to close.
|
||||||
|
if _, ok := <-tokChan; !ok {
|
||||||
|
ts.dlog("response finished")
|
||||||
|
return parseRespIterDone
|
||||||
|
}
|
||||||
|
return parseRespIterContinue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
|
||||||
|
ts := &parseResp{
|
||||||
|
sess: sess,
|
||||||
|
ctxDone: ctx.Done(),
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
// Ensure any remaining error is piped through
|
||||||
|
// or the query may look like it executed when it actually failed.
|
||||||
|
if ts.cancelError != nil {
|
||||||
|
ch <- ts.cancelError
|
||||||
|
ts.cancelError = nil
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Loop over multiple responses.
|
||||||
|
for {
|
||||||
|
ts.dlog("initiating response reading")
|
||||||
|
|
||||||
|
tokChan := make(chan tokenStruct)
|
||||||
|
go processSingleResponse(sess, tokChan, outs)
|
||||||
|
|
||||||
|
// Loop over multiple tokens in response.
|
||||||
|
tokensLoop:
|
||||||
|
for {
|
||||||
|
switch ts.iter(ctx, ch, tokChan) {
|
||||||
|
case parseRespIterContinue:
|
||||||
|
// Nothing, continue to next token.
|
||||||
|
case parseRespIterNext:
|
||||||
|
break tokensLoop
|
||||||
|
case parseRespIterDone:
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
// Code generated by "stringer -type token"; DO NOT EDIT
|
||||||
|
|
||||||
|
package mssql
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
const (
|
||||||
|
_token_name_0 = "tokenReturnStatus"
|
||||||
|
_token_name_1 = "tokenColMetadata"
|
||||||
|
_token_name_2 = "tokenOrdertokenErrortokenInfo"
|
||||||
|
_token_name_3 = "tokenLoginAck"
|
||||||
|
_token_name_4 = "tokenRowtokenNbcRow"
|
||||||
|
_token_name_5 = "tokenEnvChange"
|
||||||
|
_token_name_6 = "tokenSSPI"
|
||||||
|
_token_name_7 = "tokenDonetokenDoneProctokenDoneInProc"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_token_index_0 = [...]uint8{0, 17}
|
||||||
|
_token_index_1 = [...]uint8{0, 16}
|
||||||
|
_token_index_2 = [...]uint8{0, 10, 20, 29}
|
||||||
|
_token_index_3 = [...]uint8{0, 13}
|
||||||
|
_token_index_4 = [...]uint8{0, 8, 19}
|
||||||
|
_token_index_5 = [...]uint8{0, 14}
|
||||||
|
_token_index_6 = [...]uint8{0, 9}
|
||||||
|
_token_index_7 = [...]uint8{0, 9, 22, 37}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (i token) String() string {
|
||||||
|
switch {
|
||||||
|
case i == 121:
|
||||||
|
return _token_name_0
|
||||||
|
case i == 129:
|
||||||
|
return _token_name_1
|
||||||
|
case 169 <= i && i <= 171:
|
||||||
|
i -= 169
|
||||||
|
return _token_name_2[_token_index_2[i]:_token_index_2[i+1]]
|
||||||
|
case i == 173:
|
||||||
|
return _token_name_3
|
||||||
|
case 209 <= i && i <= 210:
|
||||||
|
i -= 209
|
||||||
|
return _token_name_4[_token_index_4[i]:_token_index_4[i+1]]
|
||||||
|
case i == 227:
|
||||||
|
return _token_name_5
|
||||||
|
case i == 237:
|
||||||
|
return _token_name_6
|
||||||
|
case 253 <= i && i <= 255:
|
||||||
|
i -= 253
|
||||||
|
return _token_name_7[_token_index_7[i]:_token_index_7[i+1]]
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("token(%d)", i)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
|
package mssql
|
||||||
|
|
||||||
// Transaction Manager requests
|
// Transaction Manager requests
|
||||||
// http://msdn.microsoft.com/en-us/library/dd339887.aspx
|
// http://msdn.microsoft.com/en-us/library/dd339887.aspx
|
||||||
package mssql
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
@ -16,7 +17,18 @@ const (
|
||||||
tmSaveXact = 9
|
tmSaveXact = 9
|
||||||
)
|
)
|
||||||
|
|
||||||
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation uint8,
|
type isoLevel uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
isolationUseCurrent isoLevel = 0
|
||||||
|
isolationReadUncommited = 1
|
||||||
|
isolationReadCommited = 2
|
||||||
|
isolationRepeatableRead = 3
|
||||||
|
isolationSerializable = 4
|
||||||
|
isolationSnapshot = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel,
|
||||||
name string) (err error) {
|
name string) (err error) {
|
||||||
buf.BeginPacket(packTransMgrReq)
|
buf.BeginPacket(packTransMgrReq)
|
||||||
writeAllHeaders(buf, headers)
|
writeAllHeaders(buf, headers)
|
||||||
|
|
|
@ -6,8 +6,11 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/denisenkom/go-mssqldb/internal/cp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// fixed-length data types
|
// fixed-length data types
|
||||||
|
@ -66,6 +69,9 @@ const (
|
||||||
typeNText = 0x63
|
typeNText = 0x63
|
||||||
typeVariant = 0x62
|
typeVariant = 0x62
|
||||||
)
|
)
|
||||||
|
const PLP_NULL = 0xFFFFFFFFFFFFFFFF
|
||||||
|
const UNKNOWN_PLP_LEN = 0xFFFFFFFFFFFFFFFE
|
||||||
|
const PLP_TERMINATOR = 0x00000000
|
||||||
|
|
||||||
// TYPE_INFO rule
|
// TYPE_INFO rule
|
||||||
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
|
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
|
||||||
|
@ -75,11 +81,32 @@ type typeInfo struct {
|
||||||
Scale uint8
|
Scale uint8
|
||||||
Prec uint8
|
Prec uint8
|
||||||
Buffer []byte
|
Buffer []byte
|
||||||
Collation collation
|
Collation cp.Collation
|
||||||
|
UdtInfo udtInfo
|
||||||
|
XmlInfo xmlInfo
|
||||||
Reader func(ti *typeInfo, r *tdsBuffer) (res interface{})
|
Reader func(ti *typeInfo, r *tdsBuffer) (res interface{})
|
||||||
Writer func(w io.Writer, ti typeInfo, buf []byte) (err error)
|
Writer func(w io.Writer, ti typeInfo, buf []byte) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Common Language Runtime (CLR) Instances
|
||||||
|
// http://msdn.microsoft.com/en-us/library/dd357962.aspx
|
||||||
|
type udtInfo struct {
|
||||||
|
//MaxByteSize uint32
|
||||||
|
DBName string
|
||||||
|
SchemaName string
|
||||||
|
TypeName string
|
||||||
|
AssemblyQualifiedName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// XML Values
|
||||||
|
// http://msdn.microsoft.com/en-us/library/dd304764.aspx
|
||||||
|
type xmlInfo struct {
|
||||||
|
SchemaPresent uint8
|
||||||
|
DBName string
|
||||||
|
OwningSchema string
|
||||||
|
XmlSchemaCollection string
|
||||||
|
}
|
||||||
|
|
||||||
func readTypeInfo(r *tdsBuffer) (res typeInfo) {
|
func readTypeInfo(r *tdsBuffer) (res typeInfo) {
|
||||||
res.TypeId = r.byte()
|
res.TypeId = r.byte()
|
||||||
switch res.TypeId {
|
switch res.TypeId {
|
||||||
|
@ -114,7 +141,8 @@ func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) {
|
||||||
switch ti.TypeId {
|
switch ti.TypeId {
|
||||||
case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4,
|
case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4,
|
||||||
typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8:
|
typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8:
|
||||||
// those are fixed length types
|
// those are fixed length
|
||||||
|
ti.Writer = writeFixedType
|
||||||
default: // all others are VARLENTYPE
|
default: // all others are VARLENTYPE
|
||||||
err = writeVarLen(w, ti)
|
err = writeVarLen(w, ti)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -124,19 +152,25 @@ func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeFixedType(w io.Writer, ti typeInfo, buf []byte) (err error) {
|
||||||
|
_, err = w.Write(buf)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
|
func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
|
||||||
switch ti.TypeId {
|
switch ti.TypeId {
|
||||||
case typeDateN:
|
case typeDateN:
|
||||||
|
ti.Writer = writeByteLenType
|
||||||
case typeTimeN, typeDateTime2N, typeDateTimeOffsetN:
|
case typeTimeN, typeDateTime2N, typeDateTimeOffsetN:
|
||||||
if err = binary.Write(w, binary.LittleEndian, ti.Scale); err != nil {
|
if err = binary.Write(w, binary.LittleEndian, ti.Scale); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ti.Writer = writeByteLenType
|
ti.Writer = writeByteLenType
|
||||||
case typeGuid, typeIntN, typeDecimal, typeNumeric,
|
case typeIntN, typeDecimal, typeNumeric,
|
||||||
typeBitN, typeDecimalN, typeNumericN, typeFltN,
|
typeBitN, typeDecimalN, typeNumericN, typeFltN,
|
||||||
typeMoneyN, typeDateTimeN, typeChar,
|
typeMoneyN, typeDateTimeN, typeChar,
|
||||||
typeVarChar, typeBinary, typeVarBinary:
|
typeVarChar, typeBinary, typeVarBinary:
|
||||||
|
|
||||||
// byle len types
|
// byle len types
|
||||||
if ti.Size > 0xff {
|
if ti.Size > 0xff {
|
||||||
panic("Invalid size for BYLELEN_TYPE")
|
panic("Invalid size for BYLELEN_TYPE")
|
||||||
|
@ -156,6 +190,14 @@ func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ti.Writer = writeByteLenType
|
ti.Writer = writeByteLenType
|
||||||
|
case typeGuid:
|
||||||
|
if !(ti.Size == 0x10 || ti.Size == 0x00) {
|
||||||
|
panic("Invalid size for BYLELEN_TYPE")
|
||||||
|
}
|
||||||
|
if err = binary.Write(w, binary.LittleEndian, uint8(ti.Size)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ti.Writer = writeByteLenType
|
||||||
case typeBigVarBin, typeBigVarChar, typeBigBinary, typeBigChar,
|
case typeBigVarBin, typeBigVarChar, typeBigBinary, typeBigChar,
|
||||||
typeNVarChar, typeNChar, typeXml, typeUdt:
|
typeNVarChar, typeNChar, typeXml, typeUdt:
|
||||||
// short len types
|
// short len types
|
||||||
|
@ -176,14 +218,19 @@ func writeVarLen(w io.Writer, ti *typeInfo) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case typeXml:
|
case typeXml:
|
||||||
var schemapresent uint8 = 0
|
if err = binary.Write(w, binary.LittleEndian, ti.XmlInfo.SchemaPresent); err != nil {
|
||||||
if err = binary.Write(w, binary.LittleEndian, schemapresent); err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case typeText, typeImage, typeNText, typeVariant:
|
case typeText, typeImage, typeNText, typeVariant:
|
||||||
// LONGLEN_TYPE
|
// LONGLEN_TYPE
|
||||||
panic("LONGLEN_TYPE not implemented")
|
if err = binary.Write(w, binary.LittleEndian, uint32(ti.Size)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = writeCollation(w, ti.Collation); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ti.Writer = writeLongLenType
|
||||||
default:
|
default:
|
||||||
panic("Invalid type")
|
panic("Invalid type")
|
||||||
}
|
}
|
||||||
|
@ -207,7 +254,7 @@ func decodeDateTime(buf []byte) time.Time {
|
||||||
0, 0, secs, ns, time.UTC)
|
0, 0, secs, ns, time.UTC)
|
||||||
}
|
}
|
||||||
|
|
||||||
func readFixedType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} {
|
||||||
r.ReadFull(ti.Buffer)
|
r.ReadFull(ti.Buffer)
|
||||||
buf := ti.Buffer
|
buf := ti.Buffer
|
||||||
switch ti.TypeId {
|
switch ti.TypeId {
|
||||||
|
@ -241,12 +288,7 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
||||||
panic("shoulnd't get here")
|
panic("shoulnd't get here")
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeFixedType(w io.Writer, ti typeInfo, buf []byte) (err error) {
|
func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} {
|
||||||
_, err = w.Write(buf)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func readByteLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
|
||||||
size := r.byte()
|
size := r.byte()
|
||||||
if size == 0 {
|
if size == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
@ -305,6 +347,10 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
||||||
default:
|
default:
|
||||||
badStreamPanicf("Invalid size for MONEYNTYPE")
|
badStreamPanicf("Invalid size for MONEYNTYPE")
|
||||||
}
|
}
|
||||||
|
case typeDateTim4:
|
||||||
|
return decodeDateTim4(buf)
|
||||||
|
case typeDateTime:
|
||||||
|
return decodeDateTime(buf)
|
||||||
case typeDateTimeN:
|
case typeDateTimeN:
|
||||||
switch len(buf) {
|
switch len(buf) {
|
||||||
case 4:
|
case 4:
|
||||||
|
@ -341,7 +387,7 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func readShortLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} {
|
||||||
size := r.uint16()
|
size := r.uint16()
|
||||||
if size == 0xffff {
|
if size == 0xffff {
|
||||||
return nil
|
return nil
|
||||||
|
@ -384,7 +430,7 @@ func writeShortLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func readLongLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} {
|
||||||
// information about this format can be found here:
|
// information about this format can be found here:
|
||||||
// http://msdn.microsoft.com/en-us/library/dd304783.aspx
|
// http://msdn.microsoft.com/en-us/library/dd304783.aspx
|
||||||
// and here:
|
// and here:
|
||||||
|
@ -415,10 +461,51 @@ func readLongLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
||||||
}
|
}
|
||||||
panic("shoulnd't get here")
|
panic("shoulnd't get here")
|
||||||
}
|
}
|
||||||
|
func writeLongLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
|
||||||
|
//textptr
|
||||||
|
err = binary.Write(w, binary.LittleEndian, byte(0x10))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
//timestamp?
|
||||||
|
err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = binary.Write(w, binary.LittleEndian, uint32(ti.Size))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = w.Write(buf)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func readCollation(r *tdsBuffer) (res cp.Collation) {
|
||||||
|
res.LcidAndFlags = r.uint32()
|
||||||
|
res.SortId = r.byte()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeCollation(w io.Writer, col cp.Collation) (err error) {
|
||||||
|
if err = binary.Write(w, binary.LittleEndian, col.LcidAndFlags); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = binary.Write(w, binary.LittleEndian, col.SortId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// reads variant value
|
// reads variant value
|
||||||
// http://msdn.microsoft.com/en-us/library/dd303302.aspx
|
// http://msdn.microsoft.com/en-us/library/dd303302.aspx
|
||||||
func readVariantType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} {
|
||||||
size := r.int32()
|
size := r.int32()
|
||||||
if size == 0 {
|
if size == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
@ -510,14 +597,14 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
||||||
|
|
||||||
// partially length prefixed stream
|
// partially length prefixed stream
|
||||||
// http://msdn.microsoft.com/en-us/library/dd340469.aspx
|
// http://msdn.microsoft.com/en-us/library/dd340469.aspx
|
||||||
func readPLPType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
|
||||||
size := r.uint64()
|
size := r.uint64()
|
||||||
var buf *bytes.Buffer
|
var buf *bytes.Buffer
|
||||||
switch size {
|
switch size {
|
||||||
case 0xffffffffffffffff:
|
case PLP_NULL:
|
||||||
// null
|
// null
|
||||||
return nil
|
return nil
|
||||||
case 0xfffffffffffffffe:
|
case UNKNOWN_PLP_LEN:
|
||||||
// size unknown
|
// size unknown
|
||||||
buf = bytes.NewBuffer(make([]byte, 0, 1000))
|
buf = bytes.NewBuffer(make([]byte, 0, 1000))
|
||||||
default:
|
default:
|
||||||
|
@ -548,15 +635,16 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) (res interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) {
|
func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) {
|
||||||
if err = binary.Write(w, binary.LittleEndian, uint64(len(buf))); err != nil {
|
if err = binary.Write(w, binary.LittleEndian, uint64(UNKNOWN_PLP_LEN)); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
chunksize := uint32(len(buf))
|
chunksize := uint32(len(buf))
|
||||||
if err = binary.Write(w, binary.LittleEndian, chunksize); err != nil {
|
if chunksize == 0 {
|
||||||
|
err = binary.Write(w, binary.LittleEndian, uint32(PLP_TERMINATOR))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if chunksize == 0 {
|
if err = binary.Write(w, binary.LittleEndian, chunksize); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err = w.Write(buf[:chunksize]); err != nil {
|
if _, err = w.Write(buf[:chunksize]); err != nil {
|
||||||
|
@ -606,19 +694,27 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) {
|
||||||
}
|
}
|
||||||
ti.Reader = readByteLenType
|
ti.Reader = readByteLenType
|
||||||
case typeXml:
|
case typeXml:
|
||||||
schemapresent := r.byte()
|
ti.XmlInfo.SchemaPresent = r.byte()
|
||||||
if schemapresent != 0 {
|
if ti.XmlInfo.SchemaPresent != 0 {
|
||||||
// just ignore this for now
|
|
||||||
// dbname
|
// dbname
|
||||||
r.BVarChar()
|
ti.XmlInfo.DBName = r.BVarChar()
|
||||||
// owning schema
|
// owning schema
|
||||||
r.BVarChar()
|
ti.XmlInfo.OwningSchema = r.BVarChar()
|
||||||
// xml schema collection
|
// xml schema collection
|
||||||
r.UsVarChar()
|
ti.XmlInfo.XmlSchemaCollection = r.UsVarChar()
|
||||||
}
|
}
|
||||||
ti.Reader = readPLPType
|
ti.Reader = readPLPType
|
||||||
|
case typeUdt:
|
||||||
|
ti.Size = int(r.uint16())
|
||||||
|
ti.UdtInfo.DBName = r.BVarChar()
|
||||||
|
ti.UdtInfo.SchemaName = r.BVarChar()
|
||||||
|
ti.UdtInfo.TypeName = r.BVarChar()
|
||||||
|
ti.UdtInfo.AssemblyQualifiedName = r.UsVarChar()
|
||||||
|
|
||||||
|
ti.Buffer = make([]byte, ti.Size)
|
||||||
|
ti.Reader = readPLPType
|
||||||
case typeBigVarBin, typeBigVarChar, typeBigBinary, typeBigChar,
|
case typeBigVarBin, typeBigVarChar, typeBigBinary, typeBigChar,
|
||||||
typeNVarChar, typeNChar, typeUdt:
|
typeNVarChar, typeNChar:
|
||||||
// short len types
|
// short len types
|
||||||
ti.Size = int(r.uint16())
|
ti.Size = int(r.uint16())
|
||||||
switch ti.TypeId {
|
switch ti.TypeId {
|
||||||
|
@ -701,7 +797,8 @@ func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte {
|
||||||
|
|
||||||
// http://msdn.microsoft.com/en-us/library/ee780895.aspx
|
// http://msdn.microsoft.com/en-us/library/ee780895.aspx
|
||||||
func decodeDateInt(buf []byte) (days int) {
|
func decodeDateInt(buf []byte) (days int) {
|
||||||
return int(buf[0]) + int(buf[1])*256 + int(buf[2])*256*256
|
days = int(buf[0]) + int(buf[1])*256 + int(buf[2])*256*256
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeDate(buf []byte) time.Time {
|
func decodeDate(buf []byte) time.Time {
|
||||||
|
@ -767,8 +864,8 @@ func dateTime2(t time.Time) (days int32, ns int64) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeChar(col collation, buf []byte) string {
|
func decodeChar(col cp.Collation, buf []byte) string {
|
||||||
return charset2utf8(col, buf)
|
return cp.CharsetToUTF8(col, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeUcs2(buf []byte) string {
|
func decodeUcs2(buf []byte) string {
|
||||||
|
@ -787,12 +884,127 @@ func decodeXml(ti typeInfo, buf []byte) string {
|
||||||
return decodeUcs2(buf)
|
return decodeUcs2(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeUdt(ti typeInfo, buf []byte) int {
|
func decodeUdt(ti typeInfo, buf []byte) []byte {
|
||||||
panic("Not implemented")
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
// makes go/sql type instance as described below
|
||||||
|
// It should return
|
||||||
|
// the value type that can be used to scan types into. For example, the database
|
||||||
|
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
|
||||||
|
func makeGoLangScanType(ti typeInfo) reflect.Type {
|
||||||
|
switch ti.TypeId {
|
||||||
|
case typeInt1:
|
||||||
|
return reflect.TypeOf(int64(0))
|
||||||
|
case typeInt2:
|
||||||
|
return reflect.TypeOf(int64(0))
|
||||||
|
case typeInt4:
|
||||||
|
return reflect.TypeOf(int64(0))
|
||||||
|
case typeInt8:
|
||||||
|
return reflect.TypeOf(int64(0))
|
||||||
|
case typeFlt4:
|
||||||
|
return reflect.TypeOf(float64(0))
|
||||||
|
case typeIntN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 1:
|
||||||
|
return reflect.TypeOf(int64(0))
|
||||||
|
case 2:
|
||||||
|
return reflect.TypeOf(int64(0))
|
||||||
|
case 4:
|
||||||
|
return reflect.TypeOf(int64(0))
|
||||||
|
case 8:
|
||||||
|
return reflect.TypeOf(int64(0))
|
||||||
|
default:
|
||||||
|
panic("invalid size of INTNTYPE")
|
||||||
|
}
|
||||||
|
case typeFlt8:
|
||||||
|
return reflect.TypeOf(float64(0))
|
||||||
|
case typeFltN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return reflect.TypeOf(float64(0))
|
||||||
|
case 8:
|
||||||
|
return reflect.TypeOf(float64(0))
|
||||||
|
default:
|
||||||
|
panic("invalid size of FLNNTYPE")
|
||||||
|
}
|
||||||
|
case typeBigVarBin:
|
||||||
|
return reflect.TypeOf([]byte{})
|
||||||
|
case typeVarChar:
|
||||||
|
return reflect.TypeOf("")
|
||||||
|
case typeNVarChar:
|
||||||
|
return reflect.TypeOf("")
|
||||||
|
case typeBit, typeBitN:
|
||||||
|
return reflect.TypeOf(true)
|
||||||
|
case typeDecimalN, typeNumericN:
|
||||||
|
return reflect.TypeOf([]byte{})
|
||||||
|
case typeMoney, typeMoney4, typeMoneyN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return reflect.TypeOf([]byte{})
|
||||||
|
case 8:
|
||||||
|
return reflect.TypeOf([]byte{})
|
||||||
|
default:
|
||||||
|
panic("invalid size of MONEYN")
|
||||||
|
}
|
||||||
|
case typeDateTim4:
|
||||||
|
return reflect.TypeOf(time.Time{})
|
||||||
|
case typeDateTime:
|
||||||
|
return reflect.TypeOf(time.Time{})
|
||||||
|
case typeDateTimeN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return reflect.TypeOf(time.Time{})
|
||||||
|
case 8:
|
||||||
|
return reflect.TypeOf(time.Time{})
|
||||||
|
default:
|
||||||
|
panic("invalid size of DATETIMEN")
|
||||||
|
}
|
||||||
|
case typeDateTime2N:
|
||||||
|
return reflect.TypeOf(time.Time{})
|
||||||
|
case typeDateN:
|
||||||
|
return reflect.TypeOf(time.Time{})
|
||||||
|
case typeTimeN:
|
||||||
|
return reflect.TypeOf(time.Time{})
|
||||||
|
case typeDateTimeOffsetN:
|
||||||
|
return reflect.TypeOf(time.Time{})
|
||||||
|
case typeBigVarChar:
|
||||||
|
return reflect.TypeOf("")
|
||||||
|
case typeBigChar:
|
||||||
|
return reflect.TypeOf("")
|
||||||
|
case typeNChar:
|
||||||
|
return reflect.TypeOf("")
|
||||||
|
case typeGuid:
|
||||||
|
return reflect.TypeOf([]byte{})
|
||||||
|
case typeXml:
|
||||||
|
return reflect.TypeOf("")
|
||||||
|
case typeText:
|
||||||
|
return reflect.TypeOf("")
|
||||||
|
case typeNText:
|
||||||
|
return reflect.TypeOf("")
|
||||||
|
case typeImage:
|
||||||
|
return reflect.TypeOf([]byte{})
|
||||||
|
case typeBigBinary:
|
||||||
|
return reflect.TypeOf([]byte{})
|
||||||
|
case typeVariant:
|
||||||
|
return reflect.TypeOf(nil)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("not implemented makeDecl for type %d", ti.TypeId))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeDecl(ti typeInfo) string {
|
func makeDecl(ti typeInfo) string {
|
||||||
switch ti.TypeId {
|
switch ti.TypeId {
|
||||||
|
case typeNull:
|
||||||
|
// maybe we should use something else here
|
||||||
|
// this is tested in TestNull
|
||||||
|
return "nvarchar(1)"
|
||||||
|
case typeInt1:
|
||||||
|
return "tinyint"
|
||||||
|
case typeInt2:
|
||||||
|
return "smallint"
|
||||||
|
case typeInt4:
|
||||||
|
return "int"
|
||||||
case typeInt8:
|
case typeInt8:
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case typeFlt4:
|
case typeFlt4:
|
||||||
|
@ -821,24 +1033,415 @@ func makeDecl(ti typeInfo) string {
|
||||||
default:
|
default:
|
||||||
panic("invalid size of FLNNTYPE")
|
panic("invalid size of FLNNTYPE")
|
||||||
}
|
}
|
||||||
|
case typeDecimal, typeDecimalN:
|
||||||
|
return fmt.Sprintf("decimal(%d, %d)", ti.Prec, ti.Scale)
|
||||||
|
case typeNumeric, typeNumericN:
|
||||||
|
return fmt.Sprintf("numeric(%d, %d)", ti.Prec, ti.Scale)
|
||||||
|
case typeMoney4:
|
||||||
|
return "smallmoney"
|
||||||
|
case typeMoney:
|
||||||
|
return "money"
|
||||||
|
case typeMoneyN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return "smallmoney"
|
||||||
|
case 8:
|
||||||
|
return "money"
|
||||||
|
default:
|
||||||
|
panic("invalid size of MONEYNTYPE")
|
||||||
|
}
|
||||||
case typeBigVarBin:
|
case typeBigVarBin:
|
||||||
if ti.Size > 8000 || ti.Size == 0 {
|
if ti.Size > 8000 || ti.Size == 0 {
|
||||||
return fmt.Sprintf("varbinary(max)")
|
return "varbinary(max)"
|
||||||
} else {
|
} else {
|
||||||
return fmt.Sprintf("varbinary(%d)", ti.Size)
|
return fmt.Sprintf("varbinary(%d)", ti.Size)
|
||||||
}
|
}
|
||||||
|
case typeNChar:
|
||||||
|
return fmt.Sprintf("nchar(%d)", ti.Size/2)
|
||||||
|
case typeBigChar, typeChar:
|
||||||
|
return fmt.Sprintf("char(%d)", ti.Size)
|
||||||
|
case typeBigVarChar, typeVarChar:
|
||||||
|
if ti.Size > 4000 || ti.Size == 0 {
|
||||||
|
return fmt.Sprintf("varchar(max)")
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("varchar(%d)", ti.Size)
|
||||||
|
}
|
||||||
case typeNVarChar:
|
case typeNVarChar:
|
||||||
if ti.Size > 8000 || ti.Size == 0 {
|
if ti.Size > 8000 || ti.Size == 0 {
|
||||||
return fmt.Sprintf("nvarchar(max)")
|
return "nvarchar(max)"
|
||||||
} else {
|
} else {
|
||||||
return fmt.Sprintf("nvarchar(%d)", ti.Size/2)
|
return fmt.Sprintf("nvarchar(%d)", ti.Size/2)
|
||||||
}
|
}
|
||||||
case typeBit, typeBitN:
|
case typeBit, typeBitN:
|
||||||
return "bit"
|
return "bit"
|
||||||
case typeDateTimeN:
|
case typeDateN:
|
||||||
|
return "date"
|
||||||
|
case typeDateTim4:
|
||||||
|
return "smalldatetime"
|
||||||
|
case typeDateTime:
|
||||||
return "datetime"
|
return "datetime"
|
||||||
|
case typeDateTimeN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return "smalldatetime"
|
||||||
|
case 8:
|
||||||
|
return "datetime"
|
||||||
|
default:
|
||||||
|
panic("invalid size of DATETIMNTYPE")
|
||||||
|
}
|
||||||
|
case typeDateTime2N:
|
||||||
|
return fmt.Sprintf("datetime2(%d)", ti.Scale)
|
||||||
case typeDateTimeOffsetN:
|
case typeDateTimeOffsetN:
|
||||||
return fmt.Sprintf("datetimeoffset(%d)", ti.Scale)
|
return fmt.Sprintf("datetimeoffset(%d)", ti.Scale)
|
||||||
|
case typeText:
|
||||||
|
return "text"
|
||||||
|
case typeNText:
|
||||||
|
return "ntext"
|
||||||
|
case typeUdt:
|
||||||
|
return ti.UdtInfo.TypeName
|
||||||
|
case typeGuid:
|
||||||
|
return "uniqueidentifier"
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("not implemented makeDecl for type %#x", ti.TypeId))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makes go/sql type name as described below
|
||||||
|
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
|
||||||
|
// database system type name without the length. Type names should be uppercase.
|
||||||
|
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
|
||||||
|
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
|
||||||
|
// "TIMESTAMP".
|
||||||
|
func makeGoLangTypeName(ti typeInfo) string {
|
||||||
|
switch ti.TypeId {
|
||||||
|
case typeInt1:
|
||||||
|
return "TINYINT"
|
||||||
|
case typeInt2:
|
||||||
|
return "SMALLINT"
|
||||||
|
case typeInt4:
|
||||||
|
return "INT"
|
||||||
|
case typeInt8:
|
||||||
|
return "BIGINT"
|
||||||
|
case typeFlt4:
|
||||||
|
return "REAL"
|
||||||
|
case typeIntN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 1:
|
||||||
|
return "TINYINT"
|
||||||
|
case 2:
|
||||||
|
return "SMALLINT"
|
||||||
|
case 4:
|
||||||
|
return "INT"
|
||||||
|
case 8:
|
||||||
|
return "BIGINT"
|
||||||
|
default:
|
||||||
|
panic("invalid size of INTNTYPE")
|
||||||
|
}
|
||||||
|
case typeFlt8:
|
||||||
|
return "FLOAT"
|
||||||
|
case typeFltN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return "REAL"
|
||||||
|
case 8:
|
||||||
|
return "FLOAT"
|
||||||
|
default:
|
||||||
|
panic("invalid size of FLNNTYPE")
|
||||||
|
}
|
||||||
|
case typeBigVarBin:
|
||||||
|
return "VARBINARY"
|
||||||
|
case typeVarChar:
|
||||||
|
return "VARCHAR"
|
||||||
|
case typeNVarChar:
|
||||||
|
return "NVARCHAR"
|
||||||
|
case typeBit, typeBitN:
|
||||||
|
return "BIT"
|
||||||
|
case typeDecimalN, typeNumericN:
|
||||||
|
return "DECIMAL"
|
||||||
|
case typeMoney, typeMoney4, typeMoneyN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return "SMALLMONEY"
|
||||||
|
case 8:
|
||||||
|
return "MONEY"
|
||||||
|
default:
|
||||||
|
panic("invalid size of MONEYN")
|
||||||
|
}
|
||||||
|
case typeDateTim4:
|
||||||
|
return "SMALLDATETIME"
|
||||||
|
case typeDateTime:
|
||||||
|
return "DATETIME"
|
||||||
|
case typeDateTimeN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return "SMALLDATETIME"
|
||||||
|
case 8:
|
||||||
|
return "DATETIME"
|
||||||
|
default:
|
||||||
|
panic("invalid size of DATETIMEN")
|
||||||
|
}
|
||||||
|
case typeDateTime2N:
|
||||||
|
return "DATETIME2"
|
||||||
|
case typeDateN:
|
||||||
|
return "DATE"
|
||||||
|
case typeTimeN:
|
||||||
|
return "TIME"
|
||||||
|
case typeDateTimeOffsetN:
|
||||||
|
return "DATETIMEOFFSET"
|
||||||
|
case typeBigVarChar:
|
||||||
|
return "VARCHAR"
|
||||||
|
case typeBigChar:
|
||||||
|
return "CHAR"
|
||||||
|
case typeNChar:
|
||||||
|
return "NCHAR"
|
||||||
|
case typeGuid:
|
||||||
|
return "UNIQUEIDENTIFIER"
|
||||||
|
case typeXml:
|
||||||
|
return "XML"
|
||||||
|
case typeText:
|
||||||
|
return "TEXT"
|
||||||
|
case typeNText:
|
||||||
|
return "NTEXT"
|
||||||
|
case typeImage:
|
||||||
|
return "IMAGE"
|
||||||
|
case typeVariant:
|
||||||
|
return "SQL_VARIANT"
|
||||||
|
case typeBigBinary:
|
||||||
|
return "BINARY"
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("not implemented makeDecl for type %d", ti.TypeId))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makes go/sql type length as described below
|
||||||
|
// It should return the length
|
||||||
|
// of the column type if the column is a variable length type. If the column is
|
||||||
|
// not a variable length type ok should return false.
|
||||||
|
// If length is not limited other than system limits, it should return math.MaxInt64.
|
||||||
|
// The following are examples of returned values for various types:
|
||||||
|
// TEXT (math.MaxInt64, true)
|
||||||
|
// varchar(10) (10, true)
|
||||||
|
// nvarchar(10) (10, true)
|
||||||
|
// decimal (0, false)
|
||||||
|
// int (0, false)
|
||||||
|
// bytea(30) (30, true)
|
||||||
|
func makeGoLangTypeLength(ti typeInfo) (int64, bool) {
|
||||||
|
switch ti.TypeId {
|
||||||
|
case typeInt1:
|
||||||
|
return 0, false
|
||||||
|
case typeInt2:
|
||||||
|
return 0, false
|
||||||
|
case typeInt4:
|
||||||
|
return 0, false
|
||||||
|
case typeInt8:
|
||||||
|
return 0, false
|
||||||
|
case typeFlt4:
|
||||||
|
return 0, false
|
||||||
|
case typeIntN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 1:
|
||||||
|
return 0, false
|
||||||
|
case 2:
|
||||||
|
return 0, false
|
||||||
|
case 4:
|
||||||
|
return 0, false
|
||||||
|
case 8:
|
||||||
|
return 0, false
|
||||||
|
default:
|
||||||
|
panic("invalid size of INTNTYPE")
|
||||||
|
}
|
||||||
|
case typeFlt8:
|
||||||
|
return 0, false
|
||||||
|
case typeFltN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return 0, false
|
||||||
|
case 8:
|
||||||
|
return 0, false
|
||||||
|
default:
|
||||||
|
panic("invalid size of FLNNTYPE")
|
||||||
|
}
|
||||||
|
case typeBit, typeBitN:
|
||||||
|
return 0, false
|
||||||
|
case typeDecimalN, typeNumericN:
|
||||||
|
return 0, false
|
||||||
|
case typeMoney, typeMoney4, typeMoneyN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return 0, false
|
||||||
|
case 8:
|
||||||
|
return 0, false
|
||||||
|
default:
|
||||||
|
panic("invalid size of MONEYN")
|
||||||
|
}
|
||||||
|
case typeDateTim4, typeDateTime:
|
||||||
|
return 0, false
|
||||||
|
case typeDateTimeN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return 0, false
|
||||||
|
case 8:
|
||||||
|
return 0, false
|
||||||
|
default:
|
||||||
|
panic("invalid size of DATETIMEN")
|
||||||
|
}
|
||||||
|
case typeDateTime2N:
|
||||||
|
return 0, false
|
||||||
|
case typeDateN:
|
||||||
|
return 0, false
|
||||||
|
case typeTimeN:
|
||||||
|
return 0, false
|
||||||
|
case typeDateTimeOffsetN:
|
||||||
|
return 0, false
|
||||||
|
case typeBigVarBin:
|
||||||
|
if ti.Size == 0xffff {
|
||||||
|
return 2147483645, true
|
||||||
|
} else {
|
||||||
|
return int64(ti.Size), true
|
||||||
|
}
|
||||||
|
case typeVarChar:
|
||||||
|
return int64(ti.Size), true
|
||||||
|
case typeBigVarChar:
|
||||||
|
if ti.Size == 0xffff {
|
||||||
|
return 2147483645, true
|
||||||
|
} else {
|
||||||
|
return int64(ti.Size), true
|
||||||
|
}
|
||||||
|
case typeBigChar:
|
||||||
|
return int64(ti.Size), true
|
||||||
|
case typeNVarChar:
|
||||||
|
if ti.Size == 0xffff {
|
||||||
|
return 2147483645 / 2, true
|
||||||
|
} else {
|
||||||
|
return int64(ti.Size) / 2, true
|
||||||
|
}
|
||||||
|
case typeNChar:
|
||||||
|
return int64(ti.Size) / 2, true
|
||||||
|
case typeGuid:
|
||||||
|
return 0, false
|
||||||
|
case typeXml:
|
||||||
|
return 1073741822, true
|
||||||
|
case typeText:
|
||||||
|
return 2147483647, true
|
||||||
|
case typeNText:
|
||||||
|
return 1073741823, true
|
||||||
|
case typeImage:
|
||||||
|
return 2147483647, true
|
||||||
|
case typeVariant:
|
||||||
|
return 0, false
|
||||||
|
case typeBigBinary:
|
||||||
|
return 0, false
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("not implemented makeDecl for type %d", ti.TypeId))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makes go/sql type precision and scale as described below
|
||||||
|
// It should return the length
|
||||||
|
// of the column type if the column is a variable length type. If the column is
|
||||||
|
// not a variable length type ok should return false.
|
||||||
|
// If length is not limited other than system limits, it should return math.MaxInt64.
|
||||||
|
// The following are examples of returned values for various types:
|
||||||
|
// TEXT (math.MaxInt64, true)
|
||||||
|
// varchar(10) (10, true)
|
||||||
|
// nvarchar(10) (10, true)
|
||||||
|
// decimal (0, false)
|
||||||
|
// int (0, false)
|
||||||
|
// bytea(30) (30, true)
|
||||||
|
func makeGoLangTypePrecisionScale(ti typeInfo) (int64, int64, bool) {
|
||||||
|
switch ti.TypeId {
|
||||||
|
case typeInt1:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeInt2:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeInt4:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeInt8:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeFlt4:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeIntN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 1:
|
||||||
|
return 0, 0, false
|
||||||
|
case 2:
|
||||||
|
return 0, 0, false
|
||||||
|
case 4:
|
||||||
|
return 0, 0, false
|
||||||
|
case 8:
|
||||||
|
return 0, 0, false
|
||||||
|
default:
|
||||||
|
panic("invalid size of INTNTYPE")
|
||||||
|
}
|
||||||
|
case typeFlt8:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeFltN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return 0, 0, false
|
||||||
|
case 8:
|
||||||
|
return 0, 0, false
|
||||||
|
default:
|
||||||
|
panic("invalid size of FLNNTYPE")
|
||||||
|
}
|
||||||
|
case typeBit, typeBitN:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeDecimalN, typeNumericN:
|
||||||
|
return int64(ti.Prec), int64(ti.Scale), true
|
||||||
|
case typeMoney, typeMoney4, typeMoneyN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return 0, 0, false
|
||||||
|
case 8:
|
||||||
|
return 0, 0, false
|
||||||
|
default:
|
||||||
|
panic("invalid size of MONEYN")
|
||||||
|
}
|
||||||
|
case typeDateTim4, typeDateTime:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeDateTimeN:
|
||||||
|
switch ti.Size {
|
||||||
|
case 4:
|
||||||
|
return 0, 0, false
|
||||||
|
case 8:
|
||||||
|
return 0, 0, false
|
||||||
|
default:
|
||||||
|
panic("invalid size of DATETIMEN")
|
||||||
|
}
|
||||||
|
case typeDateTime2N:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeDateN:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeTimeN:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeDateTimeOffsetN:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeBigVarBin:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeVarChar:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeBigVarChar:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeBigChar:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeNVarChar:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeNChar:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeGuid:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeXml:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeText:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeNText:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeImage:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeVariant:
|
||||||
|
return 0, 0, false
|
||||||
|
case typeBigBinary:
|
||||||
|
return 0, 0, false
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("not implemented makeDecl for type %d", ti.TypeId))
|
panic(fmt.Sprintf("not implemented makeDecl for type %d", ti.TypeId))
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
package mssql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UniqueIdentifier [16]byte
|
||||||
|
|
||||||
|
func (u *UniqueIdentifier) Scan(v interface{}) error {
|
||||||
|
reverse := func(b []byte) {
|
||||||
|
for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
b[i], b[j] = b[j], b[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch vt := v.(type) {
|
||||||
|
case []byte:
|
||||||
|
if len(vt) != 16 {
|
||||||
|
return errors.New("mssql: invalid UniqueIdentifier length")
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw UniqueIdentifier
|
||||||
|
|
||||||
|
copy(raw[:], vt)
|
||||||
|
|
||||||
|
reverse(raw[0:4])
|
||||||
|
reverse(raw[4:6])
|
||||||
|
reverse(raw[6:8])
|
||||||
|
*u = raw
|
||||||
|
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
if len(vt) != 36 {
|
||||||
|
return errors.New("mssql: invalid UniqueIdentifier string length")
|
||||||
|
}
|
||||||
|
|
||||||
|
b := []byte(vt)
|
||||||
|
for i, c := range b {
|
||||||
|
switch c {
|
||||||
|
case '-':
|
||||||
|
b = append(b[:i], b[i+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := hex.Decode(u[:], []byte(b))
|
||||||
|
return err
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("mssql: cannot convert %T to UniqueIdentifier", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UniqueIdentifier) Value() (driver.Value, error) {
|
||||||
|
reverse := func(b []byte) {
|
||||||
|
for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
b[i], b[j] = b[j], b[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := make([]byte, len(u))
|
||||||
|
copy(raw, u[:])
|
||||||
|
|
||||||
|
reverse(raw[0:4])
|
||||||
|
reverse(raw[4:6])
|
||||||
|
reverse(raw[6:8])
|
||||||
|
|
||||||
|
return raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UniqueIdentifier) String() string {
|
||||||
|
return fmt.Sprintf("%X-%X-%X-%X-%X", u[0:4], u[4:6], u[6:8], u[8:10], u[10:])
|
||||||
|
}
|
|
@ -85,8 +85,9 @@ github.com/couchbase/vellum/utf8
|
||||||
github.com/couchbaselabs/go-couchbase
|
github.com/couchbaselabs/go-couchbase
|
||||||
# github.com/davecgh/go-spew v1.1.1
|
# github.com/davecgh/go-spew v1.1.1
|
||||||
github.com/davecgh/go-spew/spew
|
github.com/davecgh/go-spew/spew
|
||||||
# github.com/denisenkom/go-mssqldb v0.0.0-20190121005146-b04fd42d9952 => github.com/denisenkom/go-mssqldb v0.0.0-20161128230840-e32ca5036449
|
# github.com/denisenkom/go-mssqldb v0.0.0-20190121005146-b04fd42d9952 => github.com/denisenkom/go-mssqldb v0.0.0-20180314172330-6a30f4e59a44
|
||||||
github.com/denisenkom/go-mssqldb
|
github.com/denisenkom/go-mssqldb
|
||||||
|
github.com/denisenkom/go-mssqldb/internal/cp
|
||||||
# github.com/dgrijalva/jwt-go v0.0.0-20161101193935-9ed569b5d1ac
|
# github.com/dgrijalva/jwt-go v0.0.0-20161101193935-9ed569b5d1ac
|
||||||
github.com/dgrijalva/jwt-go
|
github.com/dgrijalva/jwt-go
|
||||||
# github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712
|
# github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712
|
||||||
|
|
Loading…
Reference in New Issue