Skip to content

Commit 0461d46

Browse files
NikitaDefGavrilov.Nikita2
and
Gavrilov.Nikita2
authored
Add support TVP identity (#771)
* add support identity type Co-authored-by: Gavrilov.Nikita2 <[email protected]>
1 parent c7ddec1 commit 0461d46

File tree

3 files changed

+174
-20
lines changed

3 files changed

+174
-20
lines changed

tvp_go19.go

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
//go:build go1.9
12
// +build go1.9
23

34
package mssql
@@ -16,6 +17,7 @@ import (
1617
const (
1718
jsonTag = "json"
1819
tvpTag = "tvp"
20+
tvpIdentity = "@identity"
1921
skipTagValue = "-"
2022
sqlSeparator = "."
2123
)
@@ -29,7 +31,7 @@ var (
2931
ErrorWrongTyping = errors.New("the number of elements in columnStr and tvpFieldIndexes do not align")
3032
)
3133

32-
//TVP is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server
34+
// TVP is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server
3335
type TVP struct {
3436
//TypeName mustn't be default value
3537
TypeName string
@@ -76,8 +78,8 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
7678
binary.Write(buf, binary.LittleEndian, uint16(len(columnStr)))
7779

7880
for i, column := range columnStr {
79-
binary.Write(buf, binary.LittleEndian, uint32(column.UserType))
80-
binary.Write(buf, binary.LittleEndian, uint16(column.Flags))
81+
binary.Write(buf, binary.LittleEndian, column.UserType)
82+
binary.Write(buf, binary.LittleEndian, column.Flags)
8183
writeTypeInfo(buf, &columnStr[i].ti)
8284
writeBVarChar(buf, "")
8385
}
@@ -96,6 +98,9 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
9698
refStr := reflect.ValueOf(val.Index(i).Interface())
9799
buf.WriteByte(_TVP_ROW_TOKEN)
98100
for columnStrIdx, fieldIdx := range tvpFieldIndexes {
101+
if columnStr[columnStrIdx].Flags == fDefault {
102+
continue
103+
}
99104
field := refStr.Field(fieldIdx)
100105
tvpVal := field.Interface()
101106
if tvp.verifyStandardTypeOnNull(buf, tvpVal) {
@@ -135,6 +140,11 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
135140
}
136141

137142
func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
143+
type fieldDetailStore struct {
144+
defaultValue interface{}
145+
isIdentity bool
146+
}
147+
138148
val := reflect.ValueOf(tvp.Value)
139149
var firstRow interface{}
140150
if val.Len() != 0 {
@@ -145,7 +155,7 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
145155

146156
tvpRow := reflect.TypeOf(firstRow)
147157
columnCount := tvpRow.NumField()
148-
defaultValues := make([]interface{}, 0, columnCount)
158+
defaultValues := make([]fieldDetailStore, 0, columnCount)
149159
tvpFieldIndexes := make([]int, 0, columnCount)
150160
for i := 0; i < columnCount; i++ {
151161
field := tvpRow.Field(i)
@@ -155,12 +165,19 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
155165
continue
156166
}
157167
tvpFieldIndexes = append(tvpFieldIndexes, i)
168+
isIdentity := tvpTagValue == tvpIdentity
158169
if field.Type.Kind() == reflect.Ptr {
159170
v := reflect.New(field.Type.Elem())
160-
defaultValues = append(defaultValues, v.Interface())
171+
defaultValues = append(defaultValues, fieldDetailStore{
172+
defaultValue: v.Interface(),
173+
isIdentity: isIdentity,
174+
})
161175
continue
162176
}
163-
defaultValues = append(defaultValues, tvp.createZeroType(reflect.Zero(field.Type).Interface()))
177+
defaultValues = append(defaultValues, fieldDetailStore{
178+
defaultValue: tvp.createZeroType(reflect.Zero(field.Type).Interface()),
179+
isIdentity: isIdentity,
180+
})
164181
}
165182

166183
if columnCount-len(tvpFieldIndexes) == columnCount {
@@ -176,9 +193,9 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
176193

177194
columnConfiguration := make([]columnStruct, 0, columnCount)
178195
for index, val := range defaultValues {
179-
cval, err := convertInputParameter(val)
196+
cval, err := convertInputParameter(val.defaultValue)
180197
if err != nil {
181-
return nil, nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val, err)
198+
return nil, nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val.defaultValue, err)
182199
}
183200
param, err := stmt.makeParam(cval)
184201
if err != nil {
@@ -187,6 +204,9 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
187204
column := columnStruct{
188205
ti: param.ti,
189206
}
207+
if val.isIdentity {
208+
column.Flags = fDefault
209+
}
190210
switch param.ti.TypeId {
191211
case typeNVarChar, typeBigVarBin:
192212
column.ti.Size = 0

tvp_go19_db_test.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,3 +1348,129 @@ func TestTVPUnsigned(t *testing.T) {
13481348
t.Errorf("third result set had wrong value expected: %s actual: %s", "test", result3)
13491349
}
13501350
}
1351+
1352+
func TestTVPIdentity(t *testing.T) {
1353+
type TvpIdentityExample struct {
1354+
ID int `tvp:"@identity"`
1355+
Message string
1356+
}
1357+
1358+
const (
1359+
crateSchema = `create schema TestTVPSchemaIdentity;`
1360+
1361+
dropSchema = `drop schema TestTVPSchemaIdentity;`
1362+
1363+
createTVP = `
1364+
CREATE TYPE TestTVPSchemaIdentity.exempleTVP AS TABLE
1365+
(
1366+
id int identity(1,1) not null,
1367+
message NVARCHAR(100)
1368+
)`
1369+
1370+
dropTVP = `DROP TYPE TestTVPSchemaIdentity.exempleTVP;`
1371+
1372+
procedureWithTVP = `
1373+
CREATE PROCEDURE ExecIdentityTVP
1374+
@param1 TestTVPSchemaIdentity.exempleTVP READONLY
1375+
AS
1376+
BEGIN
1377+
SET NOCOUNT ON;
1378+
SELECT * FROM @param1;
1379+
END;
1380+
`
1381+
1382+
dropProcedure = `drop PROCEDURE ExecIdentityTVP`
1383+
1384+
execTvp = `exec ExecIdentityTVP @param1;`
1385+
)
1386+
1387+
checkConnStr(t)
1388+
tl := testLogger{t: t}
1389+
defer tl.StopLogging()
1390+
SetLogger(&tl)
1391+
1392+
p := makeConnStr(t).String()
1393+
conn, err := sql.Open("sqlserver", p)
1394+
if err != nil {
1395+
log.Fatal("Open connection failed:", err.Error())
1396+
}
1397+
defer conn.Close()
1398+
1399+
_, err = conn.Exec(crateSchema)
1400+
if err != nil {
1401+
t.Fatal(err)
1402+
return
1403+
}
1404+
defer conn.Exec(dropSchema)
1405+
1406+
_, err = conn.Exec(createTVP)
1407+
if err != nil {
1408+
t.Fatal(err)
1409+
return
1410+
}
1411+
defer conn.Exec(dropTVP)
1412+
1413+
_, err = conn.Exec(procedureWithTVP)
1414+
if err != nil {
1415+
t.Fatal(err)
1416+
return
1417+
}
1418+
defer conn.Exec(dropProcedure)
1419+
1420+
exempleData := []TvpIdentityExample{
1421+
{
1422+
Message: "Hello",
1423+
},
1424+
{
1425+
Message: "World",
1426+
},
1427+
{
1428+
Message: "TVP",
1429+
},
1430+
}
1431+
1432+
tvpType := TVP{
1433+
TypeName: "TestTVPSchemaIdentity.exempleTVP",
1434+
Value: exempleData,
1435+
}
1436+
1437+
rows, err := conn.Query(execTvp,
1438+
sql.Named("param1", tvpType),
1439+
)
1440+
if err != nil {
1441+
t.Fatal(err)
1442+
}
1443+
defer rows.Close()
1444+
1445+
tvpResult := make([]TvpIdentityExample, 0)
1446+
for rows.Next() {
1447+
tvpExemple := TvpIdentityExample{}
1448+
err = rows.Scan(&tvpExemple.ID, &tvpExemple.Message)
1449+
if err != nil {
1450+
t.Fatal(err)
1451+
}
1452+
tvpResult = append(tvpResult, tvpExemple)
1453+
}
1454+
1455+
expectData := []TvpIdentityExample{
1456+
{
1457+
ID: 1,
1458+
Message: "Hello",
1459+
},
1460+
{
1461+
ID: 2,
1462+
Message: "World",
1463+
},
1464+
{
1465+
ID: 3,
1466+
Message: "TVP",
1467+
},
1468+
}
1469+
1470+
if len(expectData) != len(tvpResult) {
1471+
t.Fatal("TestTVPIdentity have to be len")
1472+
}
1473+
if !reflect.DeepEqual(expectData, tvpResult) {
1474+
t.Fatal("TestTVPIdentity have to be same")
1475+
}
1476+
}

types.go

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ const _PLP_TERMINATOR = 0x00000000
7979
const _TVP_END_TOKEN = 0x00
8080
const _TVP_ROW_TOKEN = 0x01
8181

82+
// TVP_COLMETADATA definition
83+
// https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/0dfc5367-a388-4c92-9ba4-4d28e775acbc
84+
const (
85+
fDefault = 0x200
86+
)
87+
8288
// TYPE_INFO rule
8389
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
8490
type typeInfo struct {
@@ -1353,12 +1359,13 @@ func makeGoLangTypeName(ti typeInfo) string {
13531359
// not a variable length type ok should return false.
13541360
// If length is not limited other than system limits, it should return math.MaxInt64.
13551361
// The following are examples of returned values for various types:
1356-
// TEXT (math.MaxInt64, true)
1357-
// varchar(10) (10, true)
1358-
// nvarchar(10) (10, true)
1359-
// decimal (0, false)
1360-
// int (0, false)
1361-
// bytea(30) (30, true)
1362+
//
1363+
// TEXT (math.MaxInt64, true)
1364+
// varchar(10) (10, true)
1365+
// nvarchar(10) (10, true)
1366+
// decimal (0, false)
1367+
// int (0, false)
1368+
// bytea(30) (30, true)
13621369
func makeGoLangTypeLength(ti typeInfo) (int64, bool) {
13631370
switch ti.TypeId {
13641371
case typeInt1:
@@ -1476,12 +1483,13 @@ func makeGoLangTypeLength(ti typeInfo) (int64, bool) {
14761483
// not a variable length type ok should return false.
14771484
// If length is not limited other than system limits, it should return math.MaxInt64.
14781485
// The following are examples of returned values for various types:
1479-
// TEXT (math.MaxInt64, true)
1480-
// varchar(10) (10, true)
1481-
// nvarchar(10) (10, true)
1482-
// decimal (0, false)
1483-
// int (0, false)
1484-
// bytea(30) (30, true)
1486+
//
1487+
// TEXT (math.MaxInt64, true)
1488+
// varchar(10) (10, true)
1489+
// nvarchar(10) (10, true)
1490+
// decimal (0, false)
1491+
// int (0, false)
1492+
// bytea(30) (30, true)
14851493
func makeGoLangTypePrecisionScale(ti typeInfo) (int64, int64, bool) {
14861494
switch ti.TypeId {
14871495
case typeInt1:

0 commit comments

Comments
 (0)