Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions parameter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,117 @@ func TestParameter_Inference(t *testing.T) {
})
}

// TestParameter_BigInt checks that 64-bit integers (signed and unsigned) are
// inferred as BIGINT parameters, that an explicit BigInt Parameter with a
// non-string value is stringified, and that int32 still maps to INTEGER.
func TestParameter_BigInt(t *testing.T) {
	t.Run("Should infer int64 as BIGINT", func(t *testing.T) {
		// Largest signed 64-bit value; must round-trip without truncation.
		named := []driver.NamedValue{{Value: int64(9223372036854775807)}}
		params, err := convertNamedValuesToSparkParams(named)
		require.NoError(t, err)
		require.Equal(t, "BIGINT", *params[0].Type)
		require.Equal(t, "9223372036854775807", *params[0].Value.StringValue)
	})

	t.Run("Should infer uint64 as BIGINT", func(t *testing.T) {
		// 0x123456789ABCDEF0 == 1311768467463790320 in decimal.
		named := []driver.NamedValue{{Value: uint64(0x123456789ABCDEF0)}}
		params, err := convertNamedValuesToSparkParams(named)
		require.NoError(t, err)
		require.Equal(t, "BIGINT", *params[0].Type)
		require.Equal(t, "1311768467463790320", *params[0].Value.StringValue)
	})

	t.Run("Should infer negative int64 as BIGINT", func(t *testing.T) {
		// Smallest signed 64-bit value; the sign must be preserved.
		named := []driver.NamedValue{{Value: int64(-9223372036854775808)}}
		params, err := convertNamedValuesToSparkParams(named)
		require.NoError(t, err)
		require.Equal(t, "BIGINT", *params[0].Type)
		require.Equal(t, "-9223372036854775808", *params[0].Value.StringValue)
	})

	t.Run("Should handle explicit BigInt Parameter with non-string value", func(t *testing.T) {
		// The caller supplied a typed Parameter rather than a bare int64.
		named := []driver.NamedValue{{Value: Parameter{Type: SqlBigInt, Value: int64(12345)}}}
		params, err := convertNamedValuesToSparkParams(named)
		require.NoError(t, err)
		require.Equal(t, "BIGINT", *params[0].Type)
		require.Equal(t, "12345", *params[0].Value.StringValue)
	})

	t.Run("Should preserve int32 as INTEGER", func(t *testing.T) {
		// 32-bit values must not be promoted to BIGINT.
		named := []driver.NamedValue{{Value: int32(2147483647)}}
		params, err := convertNamedValuesToSparkParams(named)
		require.NoError(t, err)
		require.Equal(t, "INTEGER", *params[0].Type)
		require.Equal(t, "2147483647", *params[0].Value.StringValue)
	})
}

// TestParameter_Float checks type inference for floating-point inputs:
// float64 maps to DOUBLE (including magnitudes outside float32 range),
// float32 maps to FLOAT, and an explicit Double Parameter with a
// non-string value is stringified.
func TestParameter_Float(t *testing.T) {
	t.Run("Should infer float64 as DOUBLE", func(t *testing.T) {
		named := []driver.NamedValue{{Value: float64(3.141592653589793)}}
		params, err := convertNamedValuesToSparkParams(named)
		require.NoError(t, err)
		require.Equal(t, "DOUBLE", *params[0].Type)
		require.Equal(t, "3.141592653589793", *params[0].Value.StringValue)
	})

	t.Run("Should infer float32 as FLOAT", func(t *testing.T) {
		named := []driver.NamedValue{{Value: float32(3.14)}}
		params, err := convertNamedValuesToSparkParams(named)
		require.NoError(t, err)
		require.Equal(t, "FLOAT", *params[0].Type)
		require.Equal(t, "3.14", *params[0].Value.StringValue)
	})

	t.Run("Should handle large float64 values", func(t *testing.T) {
		// 1e200 exceeds the float32 range, so it must stay DOUBLE.
		named := []driver.NamedValue{{Value: float64(1e200)}}
		params, err := convertNamedValuesToSparkParams(named)
		require.NoError(t, err)
		require.Equal(t, "DOUBLE", *params[0].Type)
	})

	t.Run("Should handle small float64 values", func(t *testing.T) {
		// 1e-300 underflows float32 precision, so it must stay DOUBLE.
		named := []driver.NamedValue{{Value: float64(1e-300)}}
		params, err := convertNamedValuesToSparkParams(named)
		require.NoError(t, err)
		require.Equal(t, "DOUBLE", *params[0].Type)
	})

	t.Run("Should handle explicit Double Parameter with non-string value", func(t *testing.T) {
		// The caller supplied a typed Parameter rather than a bare float64.
		named := []driver.NamedValue{{Value: Parameter{Type: SqlDouble, Value: float64(3.14159)}}}
		params, err := convertNamedValuesToSparkParams(named)
		require.NoError(t, err)
		require.Equal(t, "DOUBLE", *params[0].Type)
		require.Equal(t, "3.14159", *params[0].Value.StringValue)
	})
}

func TestParameters_ConvertToSpark(t *testing.T) {
t.Run("Should convert names parameters", func(t *testing.T) {
values := [2]driver.NamedValue{
Expand Down
20 changes: 13 additions & 7 deletions parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,17 @@ func inferType(param *Parameter) {
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int64:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
param.Value = strconv.FormatInt(value, 10)
param.Type = SqlBigInt
case uint64:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
param.Value = strconv.FormatUint(value, 10)
param.Type = SqlBigInt
case float32:
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
param.Type = SqlFloat
case float64:
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 64)
param.Type = SqlFloat
param.Value = strconv.FormatFloat(value, 'f', -1, 64)
param.Type = SqlDouble
case time.Time:
param.Value = value.Format(time.RFC3339Nano)
param.Type = SqlTimestamp
Expand Down Expand Up @@ -179,7 +179,13 @@ func convertNamedValuesToSparkParams(values []driver.NamedValue) ([]*cli_service
if sqlParam.Type == SqlVoid {
sparkValue = nil
} else {
stringValue := sqlParam.Value.(string)
var stringValue string
switch v := sqlParam.Value.(type) {
case string:
stringValue = v
Comment thread
vikrantpuppala marked this conversation as resolved.
default:
stringValue = fmt.Sprintf("%v", sqlParam.Value)
}
sparkValue = &cli_service.TSparkParameterValue{StringValue: &stringValue}
}

Expand Down
Loading