Skip to content

Commit 0025d07

Browse files
authored
update
1 parent 28bd838 commit 0025d07

5 files changed

Lines changed: 117 additions & 55 deletions

File tree

infra/conf/transport_internet.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -486,8 +486,6 @@ func (c *TLSConfig) Build() (proto.Message, error) {
486486
}
487487
config.VerifyPeerCertInNames = c.VerifyPeerCertInNames
488488

489-
config.EchConfigList = c.ECHConfigList
490-
491489
if c.ECHServerKeys != "" {
492490
EchPrivateKey, err := base64.StdEncoding.DecodeString(c.ECHServerKeys)
493491
if err != nil {
@@ -496,6 +494,7 @@ func (c *TLSConfig) Build() (proto.Message, error) {
496494
config.EchServerKeys = EchPrivateKey
497495
}
498496
config.EchForceQuery = c.ECHForceQuery
497+
config.EchConfigList = c.ECHConfigList
499498

500499
return config, nil
501500
}

transport/internet/tls/config.pb.go

Lines changed: 14 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

transport/internet/tls/config.proto

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ message Config {
9292
*/
9393
repeated string verify_peer_cert_in_names = 17;
9494

95-
string ech_config_list = 18;
95+
bytes ech_server_keys = 18;
9696

97-
bytes ech_server_keys = 19;
98-
}
97+
string ech_config_list = 19;
98+
99+
bool ech_force_query = 20;
100+
}

transport/internet/tls/ech.go

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,23 @@ func ApplyECH(c *Config, config *tls.Config) error {
3232
nameToQuery := c.ServerName
3333
var DNSServer string
3434

35+
// for server
36+
if len(c.EchServerKeys) != 0 {
37+
KeySets, err := ConvertToGoECHKeys(c.EchServerKeys)
38+
if err != nil {
39+
return errors.New("Failed to unmarshal ECHKeySetList: ", err)
40+
}
41+
config.EncryptedClientHelloKeys = KeySets
42+
}
43+
3544
// for client
3645
if len(c.EchConfigList) != 0 {
3746
defer func() {
3847
// if failed to get ECHConfig, use an invalid one to make connection fail
39-
if err != nil && c.EchForceQuery {
40-
ECHConfig = []byte{1, 1, 4, 5, 1, 4}
48+
if err != nil {
49+
if c.EchForceQuery {
50+
ECHConfig = []byte{1, 1, 4, 5, 1, 4}
51+
}
4152
}
4253
config.EncryptedClientHelloConfigList = ECHConfig
4354
}()
@@ -58,7 +69,7 @@ func ApplyECH(c *Config, config *tls.Config) error {
5869
if nameToQuery == "" {
5970
return errors.New("Using DNS for ECH Config needs serverName or use Server format example.com+https://1.1.1.1/dns-query")
6071
}
61-
ECHConfig, err = QueryRecord(nameToQuery, DNSServer)
72+
ECHConfig, err = QueryRecord(nameToQuery, DNSServer, c.EchForceQuery)
6273
if err != nil {
6374
return err
6475
}
@@ -70,15 +81,6 @@ func ApplyECH(c *Config, config *tls.Config) error {
7081
}
7182
}
7283

73-
// for server
74-
if len(c.EchServerKeys) != 0 {
75-
KeySets, err := ConvertToGoECHKeys(c.EchServerKeys)
76-
if err != nil {
77-
return errors.New("Failed to unmarshal ECHKeySetList: ", err)
78-
}
79-
config.EncryptedClientHelloKeys = KeySets
80-
}
81-
8284
return nil
8385
}
8486

@@ -91,17 +93,19 @@ type ECHConfigCache struct {
9193
type echConfigRecord struct {
9294
config []byte
9395
expire time.Time
96+
err error
9497
}
9598

9699
var (
100+
// key value must be like this: "example.com|udp://1.1.1.1"
97101
GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]()
98102
clientForECHDOH = utils.NewTypedSyncMap[string, *http.Client]()
99103
)
100104

101105
// Update updates the ECH config for given domain and server.
102106
// this method is concurrent safe, only one update request will be sent, others get the cache.
103107
// if isLockedUpdate is true, it will not try to acquire the lock.
104-
func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool) ([]byte, error) {
108+
func (c *ECHConfigCache) Update(domain string, server string, forceQuery bool, isLockedUpdate bool) ([]byte, error) {
105109
if !isLockedUpdate {
106110
c.UpdateLock.Lock()
107111
defer c.UpdateLock.Unlock()
@@ -110,13 +114,23 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo
110114
configRecord := c.configRecord.Load()
111115
if configRecord.expire.After(time.Now()) {
112116
errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain)
113-
return configRecord.config, nil
117+
return configRecord.config, configRecord.err
114118
}
115119
// Query ECH config from DNS server
116120
errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server)
117121
echConfig, ttl, err := dnsQuery(server, domain)
118122
if err != nil {
119-
return nil, err
123+
if forceQuery {
124+
return nil, err
125+
} else {
126+
configRecord = &echConfigRecord{
127+
config: nil,
128+
expire: time.Now().Add(10 * time.Minute),
129+
err: err,
130+
}
131+
c.configRecord.Store(configRecord)
132+
return echConfig, err
133+
}
120134
}
121135
configRecord = &echConfigRecord{
122136
config: echConfig,
@@ -128,30 +142,31 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo
128142

129143
// QueryRecord returns the ECH config for given domain.
130144
// If the record is not in cache or expired, it will query the DNS server and update the cache.
131-
func QueryRecord(domain string, server string) ([]byte, error) {
132-
echConfigCache, ok := GlobalECHConfigCache.Load(domain)
145+
func QueryRecord(domain string, server string, forceQuery bool) ([]byte, error) {
146+
GlobalECHConfigCacheKey := domain + "|" + server
147+
echConfigCache, ok := GlobalECHConfigCache.Load(GlobalECHConfigCacheKey)
133148
if !ok {
134149
echConfigCache = &ECHConfigCache{}
135150
echConfigCache.configRecord.Store(&echConfigRecord{})
136-
echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(domain, echConfigCache)
151+
echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(GlobalECHConfigCacheKey, echConfigCache)
137152
}
138153
configRecord := echConfigCache.configRecord.Load()
139154
if configRecord.expire.After(time.Now()) {
140155
errors.LogDebug(context.Background(), "Cache hit for domain: ", domain)
141-
return configRecord.config, nil
156+
return configRecord.config, configRecord.err
142157
}
143158

144159
// If expire is zero value, it means we are in initial state, wait for the query to finish
145160
// otherwise return old value immediately and update in a goroutine
146161
// but if the cache is too old, wait for update
147162
if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*6).Before(time.Now()) {
148-
return echConfigCache.Update(domain, server, false)
163+
return echConfigCache.Update(domain, server, false, forceQuery)
149164
} else {
150165
// If someone already acquired the lock, it means it is updating, do not start another update goroutine
151166
if echConfigCache.UpdateLock.TryLock() {
152167
go func() {
153168
defer echConfigCache.UpdateLock.Unlock()
154-
echConfigCache.Update(domain, server, true)
169+
echConfigCache.Update(domain, server, true, forceQuery)
155170
}()
156171
}
157172
return configRecord.config, nil
@@ -170,7 +185,7 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
170185
m.Id = 0
171186
msg, err := m.Pack()
172187
if err != nil {
173-
return []byte{}, 0, err
188+
return nil, 0, err
174189
}
175190
var client *http.Client
176191
if client, _ = clientForECHDOH.Load(server); client == nil {
@@ -199,20 +214,20 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
199214
}
200215
req, err := http.NewRequest("POST", server, bytes.NewReader(msg))
201216
if err != nil {
202-
return []byte{}, 0, err
217+
return nil, 0, err
203218
}
204219
req.Header.Set("Content-Type", "application/dns-message")
205220
resp, err := client.Do(req)
206221
if err != nil {
207-
return []byte{}, 0, err
222+
return nil, 0, err
208223
}
209224
defer resp.Body.Close()
210225
respBody, err := io.ReadAll(resp.Body)
211226
if err != nil {
212-
return []byte{}, 0, err
227+
return nil, 0, err
213228
}
214229
if resp.StatusCode != http.StatusOK {
215-
return []byte{}, 0, errors.New("query failed with response code:", resp.StatusCode)
230+
return nil, 0, errors.New("query failed with response code:", resp.StatusCode)
216231
}
217232
dnsResolve = respBody
218233
} else if strings.HasPrefix(server, "udp://") { // for classic udp dns server
@@ -236,25 +251,25 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
236251
}
237252
}()
238253
if err != nil {
239-
return []byte{}, 0, err
254+
return nil, 0, err
240255
}
241256
msg, err := m.Pack()
242257
if err != nil {
243-
return []byte{}, 0, err
258+
return nil, 0, err
244259
}
245260
conn.Write(msg)
246261
udpResponse := make([]byte, 512)
247262
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
248263
_, err = conn.Read(udpResponse)
249264
if err != nil {
250-
return []byte{}, 0, err
265+
return nil, 0, err
251266
}
252267
dnsResolve = udpResponse
253268
}
254269
respMsg := new(dns.Msg)
255270
err := respMsg.Unpack(dnsResolve)
256271
if err != nil {
257-
return []byte{}, 0, errors.New("failed to unpack dns response for ECH: ", err)
272+
return nil, 0, errors.New("failed to unpack dns response for ECH: ", err)
258273
}
259274
if len(respMsg.Answer) > 0 {
260275
for _, answer := range respMsg.Answer {
@@ -268,7 +283,7 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
268283
}
269284
}
270285
}
271-
return []byte{}, 0, errors.New("no ech record found")
286+
return nil, 0, errors.New("no ech record found")
272287
}
273288

274289
// reference github.com/OmarTariq612/goech

transport/internet/tls/ech_test.go

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package tls_test
1+
package tls
22

33
import (
44
"io"
@@ -8,13 +8,12 @@ import (
88
"testing"
99

1010
"github.com/xtls/xray-core/common"
11-
. "github.com/xtls/xray-core/transport/internet/tls"
1211
)
1312

1413
func TestECHDial(t *testing.T) {
1514
config := &Config{
16-
ServerName: "encryptedsni.com",
17-
EchConfigList: "udp://1.1.1.1",
15+
ServerName: "cloudflare.com",
16+
EchConfigList: "encryptedsni.com+udp://1.1.1.1",
1817
}
1918
// test concurrent Dial(to test cache problem)
2019
wg := sync.WaitGroup{}
@@ -28,7 +27,7 @@ func TestECHDial(t *testing.T) {
2827
TLSClientConfig: TLSConfig,
2928
},
3029
}
31-
resp, err := client.Get("https://encryptedsni.com/cdn-cgi/trace")
30+
resp, err := client.Get("https://cloudflare.com/cdn-cgi/trace")
3231
common.Must(err)
3332
defer resp.Body.Close()
3433
body, err := io.ReadAll(resp.Body)
@@ -40,4 +39,51 @@ func TestECHDial(t *testing.T) {
4039
}()
4140
}
4241
wg.Wait()
42+
// check cache
43+
echConfigCache, ok := GlobalECHConfigCache.Load("encryptedsni.com|udp://1.1.1.1")
44+
if !ok {
45+
t.Error("ECH config cache not found")
46+
47+
}
48+
ok = echConfigCache.UpdateLock.TryLock()
49+
if !ok {
50+
t.Error("ECH config cache dead lock detected")
51+
}
52+
echConfigCache.UpdateLock.Unlock()
53+
configRecord := echConfigCache.configRecord.Load()
54+
if configRecord == nil {
55+
t.Error("ECH config record not found in cache")
56+
}
57+
}
58+
59+
func TestECHDialFail(t *testing.T) {
60+
config := &Config{
61+
ServerName: "cloudflare.com",
62+
EchConfigList: "udp://1.1.1.1",
63+
}
64+
TLSConfig := config.GetTLSConfig()
65+
TLSConfig.NextProtos = []string{"http/1.1"}
66+
client := &http.Client{
67+
Transport: &http.Transport{
68+
TLSClientConfig: TLSConfig,
69+
},
70+
}
71+
resp, err := client.Get("https://cloudflare.com/cdn-cgi/trace")
72+
common.Must(err)
73+
defer resp.Body.Close()
74+
_, err = io.ReadAll(resp.Body)
75+
common.Must(err)
76+
// check cache
77+
echConfigCache, ok := GlobalECHConfigCache.Load("cloudflare.com|udp://1.1.1.1")
78+
if !ok {
79+
t.Error("ECH config cache not found")
80+
}
81+
configRecord := echConfigCache.configRecord.Load()
82+
if configRecord == nil {
83+
t.Error("ECH config record not found in cache")
84+
return
85+
}
86+
if configRecord.err == nil {
87+
t.Error("unexpected nil error in ECH config record")
88+
}
4389
}

0 commit comments

Comments
 (0)