diff --git a/document_test.go b/document_test.go index 1978d971..fb60798b 100644 --- a/document_test.go +++ b/document_test.go @@ -99,6 +99,15 @@ func TestLoadDocument_Empty(t *testing.T) { assert.Error(t, err) } +func TestLoadDocument_BareMergeNodeReturnsError(t *testing.T) { + assert.NotPanics(t, func() { + doc, err := NewDocument([]byte("<<")) + assert.Nil(t, doc) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode YAML to JSON") + }) +} + func TestLoadDocument_Simple_V3(t *testing.T) { yml := `openapi: 3.0.1` doc, err := NewDocument([]byte(yml)) diff --git a/utils/utils.go b/utils/utils.go index 09d5eb43..6fefbb8d 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -321,14 +321,26 @@ func ExtractValueFromInterfaceMap(name string, raw interface{}) interface{} { return nil } +// leadingMergeContent unwraps a leading YAML merge key when it has a corresponding value node. +// Malformed YAML can produce a bare `<<` node with no value; in that case we leave the original +// node slice intact and let higher-level validation return an error instead of panicking. +func leadingMergeContent(nodes []*yaml.Node) []*yaml.Node { + if len(nodes) < 2 || nodes[0] == nil || nodes[0].Tag != "!!merge" { + return nodes + } + merged := NodeAlias(nodes[1]) + if merged == nil { + return nodes + } + return merged.Content +} + // FindFirstKeyNode will locate the first key and value yaml.Node based on a key. func FindFirstKeyNode(key string, nodes []*yaml.Node, depth int) (keyNode *yaml.Node, valueNode *yaml.Node) { if depth > 40 { return nil, nil } - if nodes != nil && len(nodes) > 0 && nodes[0].Tag == "!!merge" { - nodes = NodeAlias(nodes[1]).Content - } + nodes = leadingMergeContent(nodes) for i, v := range nodes { if key != "" && key == v.Value { if i+1 >= len(nodes) { @@ -366,9 +378,7 @@ type KeyNodeSearch struct { // FindKeyNodeTop is a non-recursive search of top level nodes for a key, will not look at content. // Returns the key and value func FindKeyNodeTop(key string, nodes []*yaml.Node) (keyNode *yaml.Node, valueNode *yaml.Node) { - if nodes != nil && len(nodes) > 0 && nodes[0].Tag == "!!merge" { - nodes = NodeAlias(nodes[1]).Content - } + nodes = leadingMergeContent(nodes) for i := 0; i < len(nodes); i++ { v := nodes[i] if i%2 != 0 { @@ -387,9 +397,7 @@ func FindKeyNodeTop(key string, nodes []*yaml.Node) (keyNode *yaml.Node, valueNo // FindKeyNode is a non-recursive search of a *yaml.Node Content for a child node with a key. // Returns the key and value func FindKeyNode(key string, nodes []*yaml.Node) (keyNode *yaml.Node, valueNode *yaml.Node) { - if nodes != nil && len(nodes) > 0 && nodes[0].Tag == "!!merge" { - nodes = NodeAlias(nodes[1]).Content - } + nodes = leadingMergeContent(nodes) for i, v := range nodes { if i%2 == 0 && key == v.Value { if len(nodes) <= i+1 { @@ -419,9 +427,7 @@ func FindKeyNode(key string, nodes []*yaml.Node) (keyNode *yaml.Node, valueNode // generally different things are required from different node trees, so depending on what this function is looking at // it will return different things. func FindKeyNodeFull(key string, nodes []*yaml.Node) (keyNode *yaml.Node, labelNode *yaml.Node, valueNode *yaml.Node) { - if nodes != nil && len(nodes) > 0 && nodes[0].Tag == "!!merge" { - nodes = NodeAlias(nodes[1]).Content - } + nodes = leadingMergeContent(nodes) for i := 0; i < len(nodes); i++ { if i%2 == 0 && key == nodes[i].Value { if i+1 >= len(nodes) { @@ -460,9 +466,7 @@ func FindKeyNodeFull(key string, nodes []*yaml.Node) (keyNode *yaml.Node, labelN // FindKeyNodeFullTop is an overloaded version of FindKeyNodeFull. This version only looks at the top // level of the node and not the children. func FindKeyNodeFullTop(key string, nodes []*yaml.Node) (keyNode *yaml.Node, labelNode *yaml.Node, valueNode *yaml.Node) { - if nodes != nil && len(nodes) >= 0 && nodes[0].Tag == "!!merge" { - nodes = NodeAlias(nodes[1]).Content - } + nodes = leadingMergeContent(nodes) for i := 0; i < len(nodes); i++ { v := nodes[i] if i%2 == 0 { @@ -479,6 +483,9 @@ func FindKeyNodeFullTop(key string, nodes []*yaml.Node) (keyNode *yaml.Node, lab continue } if i%2 == 0 && key == nodes[i].Value { + if i+1 >= len(nodes) { + return NodeAlias(nodes[i]), NodeAlias(nodes[i]), NodeAlias(nodes[i]) + } return NodeAlias(nodes[i]), NodeAlias(nodes[i]), NodeAlias(nodes[i+1]) // next node is what we need. } } @@ -1236,7 +1243,7 @@ func CheckForMergeNodes(node *yaml.Node) { for i := 0; i < total; i++ { mn := node.Content[i] if i%2 == 0 { - if mn.Tag == "!!merge" { + if mn.Tag == "!!merge" && i+1 < len(node.Content) { an := node.Content[i+1].Alias if an != nil { node.Content = append(node.Content, an.Content...) // append the merged nodes diff --git a/utils/utils_test.go b/utils/utils_test.go index 19fc4c90..717f37a0 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -10,6 +10,7 @@ import ( "github.com/pb33f/jsonpath/pkg/jsonpath" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.yaml.in/yaml/v4" ) @@ -1572,6 +1573,94 @@ x-b: assert.Equal(t, "a nice string", v.Value) } +func TestFindKeyNodeHelpers_BareMergeNodeDoesNotPanic(t *testing.T) { + var rootNode yaml.Node + err := yaml.Unmarshal([]byte("<<"), &rootNode) + assert.NoError(t, err) + assert.Len(t, rootNode.Content, 1) + assert.Equal(t, "!!merge", rootNode.Content[0].Tag) + + assert.NotPanics(t, func() { + k, v := FindFirstKeyNode("openapi", rootNode.Content, 0) + assert.Nil(t, k) + assert.Nil(t, v) + + k, v = FindKeyNodeTop("openapi", rootNode.Content) + assert.Nil(t, k) + assert.Nil(t, v) + + k, v = FindKeyNode("openapi", rootNode.Content) + assert.Nil(t, k) + assert.Nil(t, v) + + k, l, v := FindKeyNodeFull("openapi", rootNode.Content) + assert.Nil(t, k) + assert.Nil(t, l) + assert.Nil(t, v) + + k, l, v = FindKeyNodeFullTop("openapi", rootNode.Content) + assert.Nil(t, k) + assert.Nil(t, l) + assert.Nil(t, v) + }) +} + +func TestFindKeyNodeHelpers_MergeKeyWithoutAliasDoesNotPanic(t *testing.T) { + nodes := []*yaml.Node{ + {Kind: yaml.ScalarNode, Tag: "!!merge", Value: "<<"}, + {Kind: yaml.ScalarNode, Tag: "!!str", Value: "not-an-alias"}, + } + + assert.NotPanics(t, func() { + k, v := FindFirstKeyNode("openapi", nodes, 0) + assert.Nil(t, k) + assert.Nil(t, v) + + k, v = FindKeyNodeTop("openapi", nodes) + assert.Nil(t, k) + assert.Nil(t, v) + + k, v = FindKeyNode("openapi", nodes) + assert.Nil(t, k) + assert.Nil(t, v) + + k, l, v := FindKeyNodeFull("openapi", nodes) + assert.Nil(t, k) + assert.Nil(t, l) + assert.Nil(t, v) + + k, l, v = FindKeyNodeFullTop("openapi", nodes) + assert.Nil(t, k) + assert.Nil(t, l) + assert.Nil(t, v) + }) +} + +func TestLeadingMergeContent_NilMergeValueReturnsOriginalNodes(t *testing.T) { + nodes := []*yaml.Node{ + {Kind: yaml.ScalarNode, Tag: "!!merge", Value: "<<"}, + nil, + } + + result := leadingMergeContent(nodes) + assert.Equal(t, nodes, result) + assert.Nil(t, result[1]) +} + +func TestFindKeyNodeFullTop_OddLengthNodesReturnsKeyNode(t *testing.T) { + nodes := []*yaml.Node{ + {Kind: yaml.ScalarNode, Tag: "!!str", Value: "openapi"}, + } + + k, l, v := FindKeyNodeFullTop("openapi", nodes) + require.NotNil(t, k) + require.NotNil(t, l) + require.NotNil(t, v) + assert.Same(t, nodes[0], k) + assert.Same(t, nodes[0], l) + assert.Same(t, nodes[0], v) +} + func TestNodeMerge(t *testing.T) { yml := []byte(`openapi: 3.0.3 any-thing: &anchorH @@ -1599,6 +1688,20 @@ func TestNodeMerge_NoNodes(t *testing.T) { assert.Nil(t, n) } +func TestCheckForMergeNodes_BareMergeNodeDoesNotPanic(t *testing.T) { + node := &yaml.Node{ + Kind: yaml.MappingNode, + Content: []*yaml.Node{ + {Kind: yaml.ScalarNode, Tag: "!!merge", Value: "<<"}, + }, + } + + assert.NotPanics(t, func() { + CheckForMergeNodes(node) + }) + assert.Len(t, node.Content, 1) +} + func TestIsNodeNull(t *testing.T) { n := &yaml.Node{ Tag: "!!null",