Skip to content

Commit 5fbc08f

Browse files
committed
fix: Update the validator criteria to allow wifi config is empty when DHCP is false
1 parent e16afa8 commit 5fbc08f

4 files changed

Lines changed: 299 additions & 7 deletions

File tree

internal/controller/httpapi/v1/devicemanagement_test.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -394,13 +394,9 @@ func TestDeviceManagement(t *testing.T) {
394394
"isTrusted": true,
395395
},
396396
mock: func(m *mocks.MockDeviceManagementFeature) {
397-
m.EXPECT().AddCertificate(
398-
context.Background(),
399-
"valid-guid",
400-
gomock.Any(),
401-
).Return("", ErrGeneral).AnyTimes()
397+
// No mock needed - validation will fail before reaching the service
402398
},
403-
expectedCode: http.StatusInternalServerError,
399+
expectedCode: http.StatusBadRequest,
404400
response: nil,
405401
},
406402
{
@@ -465,6 +461,7 @@ func TestDeviceManagement(t *testing.T) {
465461
if tc.method == http.MethodPost || tc.method == http.MethodPatch || tc.method == http.MethodDelete {
466462
reqBody, _ := json.Marshal(tc.requestBody)
467463
req, err = http.NewRequestWithContext(context.Background(), tc.method, tc.url, bytes.NewBuffer(reqBody))
464+
req.Header.Set("Content-Type", "application/json")
468465
} else {
469466
req, err = http.NewRequestWithContext(context.Background(), tc.method, tc.url, http.NoBody)
470467
}

internal/controller/httpapi/v1/profiles.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ func NewProfileRoutes(handler *gin.RouterGroup, t profiles.Feature, l logger.Int
2727
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
2828
_ = v.RegisterValidation("genpasswordwone", dto.ValidateAMTPassOrGenRan)
2929
_ = v.RegisterValidation("ciraortls", dto.ValidateCIRAOrTLS)
30+
_ = v.RegisterValidation("wifidhcp", dto.ValidateWiFiDHCP)
3031
}
3132
}
3233

internal/controller/httpapi/v1/profiles_test.go

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@ import (
66
"encoding/json"
77
"net/http"
88
"net/http/httptest"
9+
"reflect"
10+
"sync"
911
"testing"
1012

1113
"github.com/gin-gonic/gin"
14+
"github.com/gin-gonic/gin/binding"
15+
"github.com/go-playground/validator/v10"
1216
"github.com/stretchr/testify/require"
1317
gomock "go.uber.org/mock/gomock"
1418

@@ -18,9 +22,92 @@ import (
1822
"github.com/device-management-toolkit/console/pkg/logger"
1923
)
2024

25+
var (
26+
validatorOnce sync.Once
27+
sharedValidator *defaultValidator
28+
validatorInitLock sync.Mutex
29+
routeRegistrationMu sync.Mutex // Protects NewProfileRoutes calls to prevent concurrent validator registration
30+
)
31+
32+
// defaultValidator implements the gin binding.StructValidator interface
33+
type defaultValidator struct {
34+
once sync.Once
35+
validate *validator.Validate
36+
}
37+
38+
func (v *defaultValidator) ValidateStruct(obj any) error {
39+
if obj == nil {
40+
return nil
41+
}
42+
43+
value := reflect.ValueOf(obj)
44+
switch value.Kind() {
45+
case reflect.Ptr:
46+
return v.ValidateStruct(value.Elem().Interface())
47+
case reflect.Struct:
48+
return v.validateStruct(obj)
49+
case reflect.Slice, reflect.Array:
50+
count := value.Len()
51+
validateRet := make(binding.SliceValidationError, 0)
52+
53+
for i := 0; i < count; i++ {
54+
if err := v.ValidateStruct(value.Index(i).Interface()); err != nil {
55+
validateRet = append(validateRet, err)
56+
}
57+
}
58+
59+
if len(validateRet) == 0 {
60+
return nil
61+
}
62+
63+
return validateRet
64+
case reflect.Invalid, reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
65+
reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
66+
reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128,
67+
reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.String, reflect.UnsafePointer:
68+
return nil
69+
default:
70+
return nil
71+
}
72+
}
73+
74+
func (v *defaultValidator) validateStruct(obj any) error {
75+
v.lazyinit()
76+
77+
return v.validate.Struct(obj)
78+
}
79+
80+
func (v *defaultValidator) Engine() any {
81+
v.lazyinit()
82+
83+
return v.validate
84+
}
85+
86+
func (v *defaultValidator) lazyinit() {
87+
v.once.Do(func() {
88+
v.validate = validator.New()
89+
v.validate.SetTagName("binding")
90+
91+
// Register custom validators
92+
_ = v.validate.RegisterValidation("genpasswordwone", dto.ValidateAMTPassOrGenRan)
93+
_ = v.validate.RegisterValidation("ciraortls", dto.ValidateCIRAOrTLS)
94+
_ = v.validate.RegisterValidation("wifidhcp", dto.ValidateWiFiDHCP)
95+
})
96+
}
97+
2198
func profilesTest(t *testing.T) (*mocks.MockProfilesFeature, *gin.Engine) {
2299
t.Helper()
23100

101+
// Initialize shared validator once for all parallel tests
102+
validatorInitLock.Lock()
103+
validatorOnce.Do(func() {
104+
sharedValidator = &defaultValidator{}
105+
// Trigger lazy initialization to ensure validators are registered before any parallel access
106+
sharedValidator.lazyinit()
107+
binding.Validator = sharedValidator
108+
})
109+
validatorInitLock.Unlock()
110+
24111
mockCtl := gomock.NewController(t)
25112
defer mockCtl.Finish()
26113

@@ -30,7 +117,10 @@ func profilesTest(t *testing.T) (*mocks.MockProfilesFeature, *gin.Engine) {
30117
engine := gin.New()
31118
handler := engine.Group("/api/v1/admin")
32119

120+
// Serialize NewProfileRoutes calls to prevent concurrent validator registration
121+
routeRegistrationMu.Lock()
33122
NewProfileRoutes(handler, mockProfiles, log)
123+
routeRegistrationMu.Unlock()
34124

35125
return mockProfiles, engine
36126
}
@@ -277,6 +367,7 @@ func TestProfileRoutes(t *testing.T) { //nolint:gocognit // this is a test funct
277367
if tc.requestBody.ProfileName != "" {
278368
reqBody, _ := json.Marshal(tc.requestBody)
279369
req, err = http.NewRequestWithContext(context.Background(), tc.method, tc.url, bytes.NewBuffer(reqBody))
370+
req.Header.Set("Content-Type", "application/json")
280371
} else {
281372
req, err = http.NewRequestWithContext(context.Background(), tc.method, tc.url, http.NoBody)
282373
}
@@ -309,3 +400,194 @@ func TestProfileRoutes(t *testing.T) { //nolint:gocognit // this is a test funct
309400
})
310401
}
311402
}
403+
404+
func TestProfileValidation(t *testing.T) {
405+
t.Parallel()
406+
407+
tests := []struct {
408+
name string
409+
profile dto.Profile
410+
expectedCode int
411+
}{
412+
{
413+
name: "valid profile - CCM with CIRA",
414+
profile: dto.Profile{
415+
ProfileName: "test-profile",
416+
Activation: "ccmactivate",
417+
GenerateRandomPassword: true,
418+
GenerateRandomMEBxPassword: true,
419+
CIRAConfigName: stringPtr("cira-config"),
420+
DHCPEnabled: true,
421+
UserConsent: "All",
422+
TenantID: "tenant1",
423+
},
424+
expectedCode: http.StatusCreated,
425+
},
426+
{
427+
name: "valid profile - ACM with TLS",
428+
profile: dto.Profile{
429+
ProfileName: "test-profile",
430+
Activation: "acmactivate",
431+
GenerateRandomPassword: true,
432+
MEBXPassword: "P@ssw0rd123",
433+
GenerateRandomMEBxPassword: false,
434+
TLSMode: 1,
435+
TLSSigningAuthority: "SelfSigned",
436+
DHCPEnabled: true,
437+
UserConsent: "KVM",
438+
TenantID: "tenant1",
439+
},
440+
expectedCode: http.StatusCreated,
441+
},
442+
{
443+
name: "invalid - both CIRA and TLS",
444+
profile: dto.Profile{
445+
ProfileName: "test-profile",
446+
Activation: "ccmactivate",
447+
GenerateRandomPassword: true,
448+
GenerateRandomMEBxPassword: true,
449+
CIRAConfigName: stringPtr("cira-config"),
450+
TLSMode: 1,
451+
DHCPEnabled: true,
452+
UserConsent: "All",
453+
TenantID: "tenant1",
454+
},
455+
expectedCode: http.StatusBadRequest,
456+
},
457+
{
458+
name: "invalid - wifi configs without DHCP",
459+
profile: dto.Profile{
460+
ProfileName: "test-profile",
461+
Activation: "ccmactivate",
462+
GenerateRandomPassword: true,
463+
GenerateRandomMEBxPassword: true,
464+
DHCPEnabled: false,
465+
WiFiConfigs: []dto.ProfileWiFiConfigs{
466+
{ProfileName: "wifi1", Priority: 1},
467+
},
468+
UserConsent: "All",
469+
TenantID: "tenant1",
470+
},
471+
expectedCode: http.StatusBadRequest,
472+
},
473+
{
474+
name: "invalid - password set with genRandom true",
475+
profile: dto.Profile{
476+
ProfileName: "test-profile",
477+
Activation: "ccmactivate",
478+
AMTPassword: "P@ssw0rd123",
479+
GenerateRandomPassword: true,
480+
GenerateRandomMEBxPassword: true,
481+
DHCPEnabled: true,
482+
UserConsent: "All",
483+
TenantID: "tenant1",
484+
},
485+
expectedCode: http.StatusBadRequest,
486+
},
487+
{
488+
name: "invalid - invalid activation",
489+
profile: dto.Profile{
490+
ProfileName: "test-profile",
491+
Activation: "invalidactivation",
492+
GenerateRandomPassword: true,
493+
GenerateRandomMEBxPassword: true,
494+
UserConsent: "All",
495+
TenantID: "tenant1",
496+
},
497+
expectedCode: http.StatusBadRequest,
498+
},
499+
{
500+
name: "invalid - invalid TLS signing authority",
501+
profile: dto.Profile{
502+
ProfileName: "test-profile",
503+
Activation: "acmactivate",
504+
GenerateRandomPassword: true,
505+
GenerateRandomMEBxPassword: true,
506+
TLSMode: 1,
507+
TLSSigningAuthority: "InvalidAuthority",
508+
DHCPEnabled: true,
509+
UserConsent: "All",
510+
TenantID: "tenant1",
511+
},
512+
expectedCode: http.StatusBadRequest,
513+
},
514+
{
515+
name: "invalid - TLS mode out of range",
516+
profile: dto.Profile{
517+
ProfileName: "test-profile",
518+
Activation: "acmactivate",
519+
GenerateRandomPassword: true,
520+
GenerateRandomMEBxPassword: true,
521+
TLSMode: 5,
522+
TLSSigningAuthority: "SelfSigned",
523+
DHCPEnabled: true,
524+
UserConsent: "All",
525+
TenantID: "tenant1",
526+
},
527+
expectedCode: http.StatusBadRequest,
528+
},
529+
{
530+
name: "invalid - password too short",
531+
profile: dto.Profile{
532+
ProfileName: "test-profile",
533+
Activation: "acmactivate",
534+
AMTPassword: "short",
535+
GenerateRandomPassword: false,
536+
MEBXPassword: "P@ssw0rd123",
537+
GenerateRandomMEBxPassword: false,
538+
DHCPEnabled: true,
539+
UserConsent: "All",
540+
TenantID: "tenant1",
541+
},
542+
expectedCode: http.StatusBadRequest,
543+
},
544+
{
545+
name: "invalid - password missing special character",
546+
profile: dto.Profile{
547+
ProfileName: "test-profile",
548+
Activation: "acmactivate",
549+
AMTPassword: "Password123",
550+
GenerateRandomPassword: false,
551+
MEBXPassword: "P@ssw0rd123",
552+
GenerateRandomMEBxPassword: false,
553+
DHCPEnabled: true,
554+
UserConsent: "All",
555+
TenantID: "tenant1",
556+
},
557+
expectedCode: http.StatusBadRequest,
558+
},
559+
}
560+
561+
for _, tc := range tests {
562+
tc := tc
563+
564+
t.Run(tc.name, func(t *testing.T) {
565+
t.Parallel()
566+
567+
profileFeature, engine := profilesTest(t)
568+
569+
if tc.expectedCode == http.StatusCreated {
570+
profileFeature.EXPECT().Insert(context.Background(), &tc.profile).Return(&tc.profile, nil)
571+
}
572+
573+
reqBody, _ := json.Marshal(tc.profile)
574+
req, err := http.NewRequestWithContext(
575+
context.Background(),
576+
http.MethodPost,
577+
"/api/v1/admin/profiles",
578+
bytes.NewBuffer(reqBody),
579+
)
580+
require.NoError(t, err)
581+
req.Header.Set("Content-Type", "application/json")
582+
583+
w := httptest.NewRecorder()
584+
engine.ServeHTTP(w, req)
585+
586+
require.Equal(t, tc.expectedCode, w.Code)
587+
})
588+
}
589+
}
590+
591+
func stringPtr(s string) *string {
592+
return &s
593+
}

internal/entity/dto/v1/profile.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type Profile struct {
2121
DHCPEnabled bool `json:"dhcpEnabled" example:"true"`
2222
IPSyncEnabled bool `json:"ipSyncEnabled" example:"true"`
2323
LocalWiFiSyncEnabled bool `json:"localWifiSyncEnabled" example:"true"`
24-
WiFiConfigs []ProfileWiFiConfigs `json:"wifiConfigs,omitempty" binding:"excluded_if=DHCPEnabled false,dive"`
24+
WiFiConfigs []ProfileWiFiConfigs `json:"wifiConfigs,omitempty" binding:"wifidhcp,dive"`
2525
TenantID string `json:"tenantId" example:"abc123"`
2626
TLSMode int `json:"tlsMode,omitempty" binding:"omitempty,min=1,max=4,ciraortls" example:"1"`
2727
TLSCerts *TLSCerts `json:"tlsCerts,omitempty"`
@@ -66,6 +66,18 @@ var ValidateUserConsent validator.Func = func(fl validator.FieldLevel) bool {
6666
return userConsent == "none" || userConsent == "kvm" || userConsent == "all"
6767
}
6868

69+
var ValidateWiFiDHCP validator.Func = func(fl validator.FieldLevel) bool {
70+
dhcpEnabled := fl.Parent().FieldByName("DHCPEnabled").Bool()
71+
wifiConfigs := fl.Field()
72+
73+
// If WiFiConfigs has items and DHCP is disabled, fail validation
74+
if wifiConfigs.Len() > 0 && !dhcpEnabled {
75+
return false
76+
}
77+
78+
return true
79+
}
80+
6981
type ProfileCountResponse struct {
7082
Count int `json:"totalCount"`
7183
Data []Profile `json:"data"`

0 commit comments

Comments
 (0)