Skip to content

Commit 4f8b610

Browse files
committed
Add Hinter interface to allow types to set TypeHint
1 parent 1b3441a commit 4f8b610

4 files changed

Lines changed: 68 additions & 13 deletions

File tree

conn.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,17 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
7979
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
8080
return executeStatement(ctx, c.config, query, c.transactionID, args...)
8181
}
82+
83+
// CheckNamedValue allows some types to get passed down to the
84+
// executor so the TypeHint can be set
85+
func (c *Conn) CheckNamedValue(v *driver.NamedValue) error {
86+
switch v.Value.(type) {
87+
case uuid.UUID:
88+
return nil
89+
}
90+
91+
if _, ok := v.Value.(Hinter); ok {
92+
return nil
93+
}
94+
return driver.ErrSkip
95+
}

executor.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,14 @@ func executeStatement(ctx context.Context, config *config, query, transactionID
122122
name = prefix + strconv.Itoa(arg.Ordinal)
123123
}
124124

125+
val, hint, err := asField(arg.Value)
126+
if err != nil {
127+
return nil, err
128+
}
125129
param := rdsdataservice.SqlParameter{
126-
Name: aws.String(name),
127-
Value: asField(arg.Value),
130+
Name: aws.String(name),
131+
Value: val,
132+
TypeHint: hint,
128133
}
129134

130135
input.Parameters = append(input.Parameters, &param)

go.mod

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@ module github.com/savaki/dapi
22

33
go 1.14
44

5-
require github.com/aws/aws-sdk-go v1.33.13
5+
require (
6+
github.com/aws/aws-sdk-go v1.33.13
7+
github.com/gofrs/uuid v4.0.0+incompatible // indirect
8+
)

stmt.go

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,16 @@ package dapi
1717
import (
1818
"context"
1919
"database/sql/driver"
20+
"errors"
21+
"time"
22+
2023
"github.com/aws/aws-sdk-go/aws"
2124
"github.com/aws/aws-sdk-go/service/rdsdataservice"
22-
"time"
25+
"github.com/gofrs/uuid"
2326
)
2427

28+
var ErrInvalidField = errors.New("invalid field")
29+
2530
type Stmt struct {
2631
ctx context.Context
2732
config *config
@@ -58,7 +63,7 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
5863
}
5964

6065
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
61-
panic("implement me: QueryContext (Stmt)")
66+
return executeStatement(ctx, s.config, s.query, "", args...)
6267
}
6368

6469
func newStmt(ctx context.Context, config *config, query string) *Stmt {
@@ -69,23 +74,51 @@ func newStmt(ctx context.Context, config *config, query string) *Stmt {
6974
}
7075
}
7176

72-
func asField(value driver.Value) *rdsdataservice.Field {
77+
// If a type implements Hinter it can provide a type hint to the data-api directly
78+
type Hinter interface {
79+
TypeHint() string
80+
driver.Valuer
81+
}
82+
83+
func asField(value driver.Value) (*rdsdataservice.Field, *string, error) {
84+
var hint *string
85+
if v, ok := value.(Hinter); ok {
86+
hint = aws.String(v.TypeHint())
87+
} else {
88+
switch value.(type) {
89+
case time.Time:
90+
hint = aws.String("TIMESTAMP")
91+
case uuid.UUID:
92+
hint = aws.String("UUID")
93+
}
94+
}
95+
if v, ok := value.(driver.Valuer); ok {
96+
var err error
97+
value, err = v.Value()
98+
if err != nil {
99+
return nil, hint, err
100+
}
101+
}
102+
73103
switch v := value.(type) {
74104
case int64:
75-
return &rdsdataservice.Field{LongValue: aws.Int64(v)}
105+
return &rdsdataservice.Field{LongValue: aws.Int64(v)}, hint, nil
76106
case float64:
77-
return &rdsdataservice.Field{DoubleValue: aws.Float64(v)}
107+
return &rdsdataservice.Field{DoubleValue: aws.Float64(v)}, hint, nil
78108
case bool:
79-
return &rdsdataservice.Field{BooleanValue: aws.Bool(v)}
109+
return &rdsdataservice.Field{BooleanValue: aws.Bool(v)}, hint, nil
80110
case []byte:
81-
return &rdsdataservice.Field{BlobValue: v}
111+
return &rdsdataservice.Field{BlobValue: v}, hint, nil
82112
case string:
83-
return &rdsdataservice.Field{StringValue: aws.String(v)}
113+
return &rdsdataservice.Field{StringValue: aws.String(v)}, hint, nil
84114
case time.Time:
85115
s := v.Format("2006-01-02 15:04:05")
86-
return &rdsdataservice.Field{StringValue: aws.String(s)}
116+
return &rdsdataservice.Field{StringValue: aws.String(s)}, hint, nil
87117
default:
88-
return &rdsdataservice.Field{IsNull: aws.Bool(true)}
118+
if v == nil {
119+
return &rdsdataservice.Field{IsNull: aws.Bool(true)}, hint, nil
120+
}
121+
return nil, hint, ErrInvalidField
89122
}
90123
}
91124

0 commit comments

Comments
 (0)