Skip to content

Commit 6e04cf2

Browse files
authored
Merge pull request #809 from cloudskiff/fix_kms_key_crash
Fix kms_key crash
2 parents 8857bd1 + 410e3df commit 6e04cf2

File tree

3 files changed

+157
-24
lines changed

3 files changed

+157
-24
lines changed

pkg/remote/aws/repository/kms_repository.go

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
package repository
22

33
import (
4+
"fmt"
45
"strings"
6+
"sync"
57

8+
"github.com/aws/aws-sdk-go/aws"
69
"github.com/aws/aws-sdk-go/aws/session"
710
"github.com/aws/aws-sdk-go/service/kms"
811
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
912
"github.com/cloudskiff/driftctl/pkg/remote/cache"
13+
"github.com/sirupsen/logrus"
1014
)
1115

1216
type KMSRepository interface {
@@ -15,14 +19,16 @@ type KMSRepository interface {
1519
}
1620

1721
type kmsRepository struct {
18-
client kmsiface.KMSAPI
19-
cache cache.Cache
22+
client kmsiface.KMSAPI
23+
cache cache.Cache
24+
describeKeyLock *sync.Mutex
2025
}
2126

2227
func NewKMSRepository(session *session.Session, c cache.Cache) *kmsRepository {
2328
return &kmsRepository{
2429
kms.New(session),
2530
c,
31+
&sync.Mutex{},
2632
}
2733
}
2834

@@ -68,33 +74,73 @@ func (r *kmsRepository) ListAllAliases() ([]*kms.AliasListEntry, error) {
6874
return nil, err
6975
}
7076

71-
result := r.filterAliases(aliases)
77+
result, err := r.filterAliases(aliases)
78+
if err != nil {
79+
return nil, err
80+
}
7281
r.cache.Put("kmsListAllAliases", result)
7382
return result, nil
7483
}
7584

85+
func (r *kmsRepository) describeKey(keyId *string) (*kms.DescribeKeyOutput, error) {
86+
var results interface{}
87+
// Since this method can be call in parallel, we should lock and unlock if we want to be sure to hit the cache
88+
r.describeKeyLock.Lock()
89+
defer r.describeKeyLock.Unlock()
90+
cacheKey := fmt.Sprintf("kmsDescribeKey-%s", *keyId)
91+
results = r.cache.Get(cacheKey)
92+
if results == nil {
93+
var err error
94+
results, err = r.client.DescribeKey(&kms.DescribeKeyInput{KeyId: keyId})
95+
if err != nil {
96+
return nil, err
97+
}
98+
r.cache.Put(cacheKey, results)
99+
}
100+
describeKey := results.(*kms.DescribeKeyOutput)
101+
if aws.StringValue(describeKey.KeyMetadata.KeyState) == kms.KeyStatePendingDeletion {
102+
return nil, nil
103+
}
104+
return describeKey, nil
105+
}
106+
76107
func (r *kmsRepository) filterKeys(keys []*kms.KeyListEntry) ([]*kms.KeyListEntry, error) {
77108
var customerKeys []*kms.KeyListEntry
78109
for _, key := range keys {
79-
k, err := r.client.DescribeKey(&kms.DescribeKeyInput{
80-
KeyId: key.KeyId,
81-
})
110+
k, err := r.describeKey(key.KeyId)
82111
if err != nil {
83112
return nil, err
84113
}
114+
if k == nil {
115+
logrus.WithFields(logrus.Fields{
116+
"id": *key.KeyId,
117+
}).Debug("Ignored kms key from listing since it is pending from deletion")
118+
continue
119+
}
85120
if k.KeyMetadata.KeyManager != nil && *k.KeyMetadata.KeyManager != "AWS" {
86121
customerKeys = append(customerKeys, key)
87122
}
88123
}
89124
return customerKeys, nil
90125
}
91126

92-
func (r *kmsRepository) filterAliases(aliases []*kms.AliasListEntry) []*kms.AliasListEntry {
127+
func (r *kmsRepository) filterAliases(aliases []*kms.AliasListEntry) ([]*kms.AliasListEntry, error) {
93128
var customerAliases []*kms.AliasListEntry
94129
for _, alias := range aliases {
95130
if alias.AliasName != nil && !strings.HasPrefix(*alias.AliasName, "alias/aws/") {
131+
k, err := r.describeKey(alias.TargetKeyId)
132+
if err != nil {
133+
return nil, err
134+
}
135+
if k == nil {
136+
logrus.WithFields(logrus.Fields{
137+
"id": *alias.TargetKeyId,
138+
"alias": *alias.AliasName,
139+
}).Debug("Ignored kms key alias from listing since it is linked to a pending from deletion key")
140+
continue
141+
}
96142
customerAliases = append(customerAliases, alias)
97143
}
98144
}
99-
return customerAliases
145+
return customerAliases, nil
100146
}

pkg/remote/aws/repository/kms_repository_test.go

Lines changed: 95 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package repository
22

33
import (
44
"strings"
5+
"sync"
56
"testing"
67

78
"github.com/aws/aws-sdk-go/aws"
@@ -21,6 +22,45 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) {
2122
want []*kms.KeyListEntry
2223
wantErr error
2324
}{
25+
{
26+
name: "List only enabled keys",
27+
mocks: func(client *awstest.MockFakeKMS) {
28+
client.On("ListKeysPages",
29+
&kms.ListKeysInput{},
30+
mock.MatchedBy(func(callback func(res *kms.ListKeysOutput, lastPage bool) bool) bool {
31+
callback(&kms.ListKeysOutput{
32+
Keys: []*kms.KeyListEntry{
33+
{KeyId: aws.String("1")},
34+
{KeyId: aws.String("2")},
35+
},
36+
}, true)
37+
return true
38+
})).Return(nil).Once()
39+
client.On("DescribeKey",
40+
&kms.DescribeKeyInput{
41+
KeyId: aws.String("1"),
42+
}).Return(&kms.DescribeKeyOutput{
43+
KeyMetadata: &kms.KeyMetadata{
44+
KeyId: aws.String("1"),
45+
KeyManager: aws.String("CUSTOMER"),
46+
KeyState: aws.String(kms.KeyStateEnabled),
47+
},
48+
}, nil).Once()
49+
client.On("DescribeKey",
50+
&kms.DescribeKeyInput{
51+
KeyId: aws.String("2"),
52+
}).Return(&kms.DescribeKeyOutput{
53+
KeyMetadata: &kms.KeyMetadata{
54+
KeyId: aws.String("2"),
55+
KeyManager: aws.String("CUSTOMER"),
56+
KeyState: aws.String(kms.KeyStatePendingDeletion),
57+
},
58+
}, nil).Once()
59+
},
60+
want: []*kms.KeyListEntry{
61+
{KeyId: aws.String("1")},
62+
},
63+
},
2464
{
2565
name: "List only customer keys",
2666
mocks: func(client *awstest.MockFakeKMS) {
@@ -43,6 +83,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) {
4383
KeyMetadata: &kms.KeyMetadata{
4484
KeyId: aws.String("1"),
4585
KeyManager: aws.String("CUSTOMER"),
86+
KeyState: aws.String(kms.KeyStateEnabled),
4687
},
4788
}, nil).Once()
4889
client.On("DescribeKey",
@@ -52,6 +93,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) {
5293
KeyMetadata: &kms.KeyMetadata{
5394
KeyId: aws.String("2"),
5495
KeyManager: aws.String("AWS"),
96+
KeyState: aws.String(kms.KeyStateEnabled),
5597
},
5698
}, nil).Once()
5799
client.On("DescribeKey",
@@ -61,6 +103,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) {
61103
KeyMetadata: &kms.KeyMetadata{
62104
KeyId: aws.String("3"),
63105
KeyManager: aws.String("AWS"),
106+
KeyState: aws.String(kms.KeyStateEnabled),
64107
},
65108
}, nil).Once()
66109
},
@@ -75,8 +118,9 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) {
75118
client := awstest.MockFakeKMS{}
76119
tt.mocks(&client)
77120
r := &kmsRepository{
78-
client: &client,
79-
cache: store,
121+
client: &client,
122+
cache: store,
123+
describeKeyLock: &sync.Mutex{},
80124
}
81125
got, err := r.ListAllKeys()
82126
assert.Equal(t, tt.wantErr, err)
@@ -108,6 +152,35 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) {
108152
want []*kms.AliasListEntry
109153
wantErr error
110154
}{
155+
{
156+
name: "List only aliases for enabled keys",
157+
mocks: func(client *awstest.MockFakeKMS) {
158+
client.On("ListAliasesPages",
159+
&kms.ListAliasesInput{},
160+
mock.MatchedBy(func(callback func(res *kms.ListAliasesOutput, lastPage bool) bool) bool {
161+
callback(&kms.ListAliasesOutput{
162+
Aliases: []*kms.AliasListEntry{
163+
{AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")},
164+
{AliasName: aws.String("alias/2"), TargetKeyId: aws.String("key-id-2")},
165+
},
166+
}, true)
167+
return true
168+
})).Return(nil).Once()
169+
client.On("DescribeKey", &kms.DescribeKeyInput{KeyId: aws.String("key-id-1")}).Return(&kms.DescribeKeyOutput{
170+
KeyMetadata: &kms.KeyMetadata{
171+
KeyState: aws.String(kms.KeyStatePendingDeletion),
172+
},
173+
}, nil)
174+
client.On("DescribeKey", &kms.DescribeKeyInput{KeyId: aws.String("key-id-2")}).Return(&kms.DescribeKeyOutput{
175+
KeyMetadata: &kms.KeyMetadata{
176+
KeyState: aws.String(kms.KeyStateEnabled),
177+
},
178+
}, nil)
179+
},
180+
want: []*kms.AliasListEntry{
181+
{AliasName: aws.String("alias/2"), TargetKeyId: aws.String("key-id-2")},
182+
},
183+
},
111184
{
112185
name: "List only customer aliases",
113186
mocks: func(client *awstest.MockFakeKMS) {
@@ -116,24 +189,29 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) {
116189
mock.MatchedBy(func(callback func(res *kms.ListAliasesOutput, lastPage bool) bool) bool {
117190
callback(&kms.ListAliasesOutput{
118191
Aliases: []*kms.AliasListEntry{
119-
{AliasName: aws.String("alias/1")},
120-
{AliasName: aws.String("alias/foo/2")},
121-
{AliasName: aws.String("alias/aw/3")},
122-
{AliasName: aws.String("alias/aws/4")},
123-
{AliasName: aws.String("alias/aws/5")},
124-
{AliasName: aws.String("alias/awss/6")},
125-
{AliasName: aws.String("alias/aws7")},
192+
{AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")},
193+
{AliasName: aws.String("alias/foo/2"), TargetKeyId: aws.String("key-id-2")},
194+
{AliasName: aws.String("alias/aw/3"), TargetKeyId: aws.String("key-id-3")},
195+
{AliasName: aws.String("alias/aws/4"), TargetKeyId: aws.String("key-id-4")},
196+
{AliasName: aws.String("alias/aws/5"), TargetKeyId: aws.String("key-id-5")},
197+
{AliasName: aws.String("alias/awss/6"), TargetKeyId: aws.String("key-id-6")},
198+
{AliasName: aws.String("alias/aws7"), TargetKeyId: aws.String("key-id-7")},
126199
},
127200
}, true)
128201
return true
129202
})).Return(nil).Once()
203+
client.On("DescribeKey", mock.Anything).Return(&kms.DescribeKeyOutput{
204+
KeyMetadata: &kms.KeyMetadata{
205+
KeyState: aws.String(kms.KeyStateEnabled),
206+
},
207+
}, nil)
130208
},
131209
want: []*kms.AliasListEntry{
132-
{AliasName: aws.String("alias/1")},
133-
{AliasName: aws.String("alias/foo/2")},
134-
{AliasName: aws.String("alias/aw/3")},
135-
{AliasName: aws.String("alias/awss/6")},
136-
{AliasName: aws.String("alias/aws7")},
210+
{AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")},
211+
{AliasName: aws.String("alias/foo/2"), TargetKeyId: aws.String("key-id-2")},
212+
{AliasName: aws.String("alias/aw/3"), TargetKeyId: aws.String("key-id-3")},
213+
{AliasName: aws.String("alias/awss/6"), TargetKeyId: aws.String("key-id-6")},
214+
{AliasName: aws.String("alias/aws7"), TargetKeyId: aws.String("key-id-7")},
137215
},
138216
},
139217
}
@@ -143,8 +221,9 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) {
143221
client := awstest.MockFakeKMS{}
144222
tt.mocks(&client)
145223
r := &kmsRepository{
146-
client: &client,
147-
cache: store,
224+
client: &client,
225+
cache: store,
226+
describeKeyLock: &sync.Mutex{},
148227
}
149228
got, err := r.ListAllAliases()
150229
assert.Equal(t, tt.wantErr, err)

pkg/remote/common/details_fetcher.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package common
33
import (
44
"github.com/cloudskiff/driftctl/pkg/resource"
55
"github.com/cloudskiff/driftctl/pkg/terraform"
6+
"github.com/sirupsen/logrus"
67
)
78

89
type DetailsFetcher interface {
@@ -31,6 +32,13 @@ func (f *GenericDetailsFetcher) ReadDetails(res resource.Resource) (resource.Res
3132
if err != nil {
3233
return nil, err
3334
}
35+
if ctyVal.IsNull() {
36+
logrus.WithFields(logrus.Fields{
37+
"type": f.resType,
38+
"id": res.TerraformId(),
39+
}).Debug("Got null while reading resource details")
40+
return nil, nil
41+
}
3442
deserializedRes, err := f.deserializer.DeserializeOne(string(f.resType), *ctyVal)
3543
if err != nil {
3644
return nil, err

0 commit comments

Comments
 (0)