@@ -32,12 +32,23 @@ func ApplyECH(c *Config, config *tls.Config) error {
3232 nameToQuery := c .ServerName
3333 var DNSServer string
3434
35+ // for server
36+ if len (c .EchServerKeys ) != 0 {
37+ KeySets , err := ConvertToGoECHKeys (c .EchServerKeys )
38+ if err != nil {
39+ return errors .New ("Failed to unmarshal ECHKeySetList: " , err )
40+ }
41+ config .EncryptedClientHelloKeys = KeySets
42+ }
43+
3544 // for client
3645 if len (c .EchConfigList ) != 0 {
3746 defer func () {
3847 // if failed to get ECHConfig, use an invalid one to make connection fail
39- if err != nil && c .EchForceQuery {
40- ECHConfig = []byte {1 , 1 , 4 , 5 , 1 , 4 }
48+ if err != nil {
49+ if c .EchForceQuery {
50+ ECHConfig = []byte {1 , 1 , 4 , 5 , 1 , 4 }
51+ }
4152 }
4253 config .EncryptedClientHelloConfigList = ECHConfig
4354 }()
@@ -58,7 +69,7 @@ func ApplyECH(c *Config, config *tls.Config) error {
5869 if nameToQuery == "" {
5970 return errors .New ("Using DNS for ECH Config needs serverName or use Server format example.com+https://1.1.1.1/dns-query" )
6071 }
61- ECHConfig , err = QueryRecord (nameToQuery , DNSServer )
72+ ECHConfig , err = QueryRecord (nameToQuery , DNSServer , c . EchForceQuery )
6273 if err != nil {
6374 return err
6475 }
@@ -70,15 +81,6 @@ func ApplyECH(c *Config, config *tls.Config) error {
7081 }
7182 }
7283
73- // for server
74- if len (c .EchServerKeys ) != 0 {
75- KeySets , err := ConvertToGoECHKeys (c .EchServerKeys )
76- if err != nil {
77- return errors .New ("Failed to unmarshal ECHKeySetList: " , err )
78- }
79- config .EncryptedClientHelloKeys = KeySets
80- }
81-
8284 return nil
8385}
8486
@@ -91,17 +93,19 @@ type ECHConfigCache struct {
9193type echConfigRecord struct {
9294 config []byte
9395 expire time.Time
96+ err error
9497}
9598
9699var (
100+ // key value must be like this: "example.com|udp://1.1.1.1"
97101 GlobalECHConfigCache = utils .NewTypedSyncMap [string , * ECHConfigCache ]()
98102 clientForECHDOH = utils .NewTypedSyncMap [string , * http.Client ]()
99103)
100104
101105// Update updates the ECH config for given domain and server.
102106// this method is concurrent safe, only one update request will be sent, others get the cache.
103107// if isLockedUpdate is true, it will not try to acquire the lock.
104- func (c * ECHConfigCache ) Update (domain string , server string , isLockedUpdate bool ) ([]byte , error ) {
108+ func (c * ECHConfigCache ) Update (domain string , server string , forceQuery bool , isLockedUpdate bool ) ([]byte , error ) {
105109 if ! isLockedUpdate {
106110 c .UpdateLock .Lock ()
107111 defer c .UpdateLock .Unlock ()
@@ -110,13 +114,23 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo
110114 configRecord := c .configRecord .Load ()
111115 if configRecord .expire .After (time .Now ()) {
112116 errors .LogDebug (context .Background (), "Cache hit for domain after double check: " , domain )
113- return configRecord .config , nil
117+ return configRecord .config , configRecord . err
114118 }
115119 // Query ECH config from DNS server
116120 errors .LogDebug (context .Background (), "Trying to query ECH config for domain: " , domain , " with ECH server: " , server )
117121 echConfig , ttl , err := dnsQuery (server , domain )
118122 if err != nil {
119- return nil , err
123+ if forceQuery {
124+ return nil , err
125+ } else {
126+ configRecord = & echConfigRecord {
127+ config : nil ,
128+ expire : time .Now ().Add (10 * time .Minute ),
129+ err : err ,
130+ }
131+ c .configRecord .Store (configRecord )
132+ return echConfig , err
133+ }
120134 }
121135 configRecord = & echConfigRecord {
122136 config : echConfig ,
@@ -128,30 +142,31 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo
128142
129143// QueryRecord returns the ECH config for given domain.
130144// If the record is not in cache or expired, it will query the DNS server and update the cache.
131- func QueryRecord (domain string , server string ) ([]byte , error ) {
132- echConfigCache , ok := GlobalECHConfigCache .Load (domain )
145+ func QueryRecord (domain string , server string , forceQuery bool ) ([]byte , error ) {
146+ GlobalECHConfigCacheKey := domain + "|" + server
147+ echConfigCache , ok := GlobalECHConfigCache .Load (GlobalECHConfigCacheKey )
133148 if ! ok {
134149 echConfigCache = & ECHConfigCache {}
135150 echConfigCache .configRecord .Store (& echConfigRecord {})
136- echConfigCache , _ = GlobalECHConfigCache .LoadOrStore (domain , echConfigCache )
151+ echConfigCache , _ = GlobalECHConfigCache .LoadOrStore (GlobalECHConfigCacheKey , echConfigCache )
137152 }
138153 configRecord := echConfigCache .configRecord .Load ()
139154 if configRecord .expire .After (time .Now ()) {
140155 errors .LogDebug (context .Background (), "Cache hit for domain: " , domain )
141- return configRecord .config , nil
156+ return configRecord .config , configRecord . err
142157 }
143158
144159 // If expire is zero value, it means we are in initial state, wait for the query to finish
145160 // otherwise return old value immediately and update in a goroutine
146161 // but if the cache is too old, wait for update
147162 if configRecord .expire == (time.Time {}) || configRecord .expire .Add (time .Hour * 6 ).Before (time .Now ()) {
148- return echConfigCache .Update (domain , server , false )
163+ return echConfigCache .Update (domain , server , false , forceQuery )
149164 } else {
150165 // If someone already acquired the lock, it means it is updating, do not start another update goroutine
151166 if echConfigCache .UpdateLock .TryLock () {
152167 go func () {
153168 defer echConfigCache .UpdateLock .Unlock ()
154- echConfigCache .Update (domain , server , true )
169+ echConfigCache .Update (domain , server , true , forceQuery )
155170 }()
156171 }
157172 return configRecord .config , nil
@@ -170,7 +185,7 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
170185 m .Id = 0
171186 msg , err := m .Pack ()
172187 if err != nil {
173- return [] byte {} , 0 , err
188+ return nil , 0 , err
174189 }
175190 var client * http.Client
176191 if client , _ = clientForECHDOH .Load (server ); client == nil {
@@ -199,20 +214,20 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
199214 }
200215 req , err := http .NewRequest ("POST" , server , bytes .NewReader (msg ))
201216 if err != nil {
202- return [] byte {} , 0 , err
217+ return nil , 0 , err
203218 }
204219 req .Header .Set ("Content-Type" , "application/dns-message" )
205220 resp , err := client .Do (req )
206221 if err != nil {
207- return [] byte {} , 0 , err
222+ return nil , 0 , err
208223 }
209224 defer resp .Body .Close ()
210225 respBody , err := io .ReadAll (resp .Body )
211226 if err != nil {
212- return [] byte {} , 0 , err
227+ return nil , 0 , err
213228 }
214229 if resp .StatusCode != http .StatusOK {
215- return [] byte {} , 0 , errors .New ("query failed with response code:" , resp .StatusCode )
230+ return nil , 0 , errors .New ("query failed with response code:" , resp .StatusCode )
216231 }
217232 dnsResolve = respBody
218233 } else if strings .HasPrefix (server , "udp://" ) { // for classic udp dns server
@@ -236,25 +251,25 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
236251 }
237252 }()
238253 if err != nil {
239- return [] byte {} , 0 , err
254+ return nil , 0 , err
240255 }
241256 msg , err := m .Pack ()
242257 if err != nil {
243- return [] byte {} , 0 , err
258+ return nil , 0 , err
244259 }
245260 conn .Write (msg )
246261 udpResponse := make ([]byte , 512 )
247262 conn .SetReadDeadline (time .Now ().Add (5 * time .Second ))
248263 _ , err = conn .Read (udpResponse )
249264 if err != nil {
250- return [] byte {} , 0 , err
265+ return nil , 0 , err
251266 }
252267 dnsResolve = udpResponse
253268 }
254269 respMsg := new (dns.Msg )
255270 err := respMsg .Unpack (dnsResolve )
256271 if err != nil {
257- return [] byte {} , 0 , errors .New ("failed to unpack dns response for ECH: " , err )
272+ return nil , 0 , errors .New ("failed to unpack dns response for ECH: " , err )
258273 }
259274 if len (respMsg .Answer ) > 0 {
260275 for _ , answer := range respMsg .Answer {
@@ -268,7 +283,7 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
268283 }
269284 }
270285 }
271- return [] byte {} , 0 , errors .New ("no ech record found" )
286+ return nil , 0 , errors .New ("no ech record found" )
272287}
273288
274289// reference github.com/OmarTariq612/goech
0 commit comments