Skip to content

Commit 4e35b36

Browse files
committed
more test coverage
1 parent 85f7011 commit 4e35b36

15 files changed

Lines changed: 2133 additions & 275 deletions

api/http_test.go

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
package api
2+
3+
import (
4+
"bytes"
5+
"encoding/hex"
6+
"encoding/json"
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
"time"
11+
12+
"github.com/MixinNetwork/tip/crypto"
13+
"github.com/drand/kyber"
14+
"github.com/drand/kyber/pairing/bn256"
15+
"github.com/drand/kyber/share"
16+
"github.com/drand/kyber/share/dkg"
17+
"github.com/drand/kyber/util/random"
18+
"github.com/stretchr/testify/require"
19+
"github.com/unrolled/render"
20+
)
21+
22+
type stubStore struct {
23+
watchFn func([]byte) ([]byte, time.Time, int, error)
24+
}
25+
26+
func (s *stubStore) CheckPolyGroup([]byte) (bool, error) {
27+
return false, nil
28+
}
29+
30+
func (s *stubStore) ReadPolyPublic() ([]byte, error) {
31+
return nil, nil
32+
}
33+
34+
func (s *stubStore) ReadPolyShare() ([]byte, error) {
35+
return nil, nil
36+
}
37+
38+
func (s *stubStore) WritePoly([]byte, []byte) error {
39+
return nil
40+
}
41+
42+
func (s *stubStore) WriteAssignee([]byte, []byte) error {
43+
return nil
44+
}
45+
46+
func (s *stubStore) ReadAssignor([]byte) ([]byte, error) {
47+
return nil, nil
48+
}
49+
50+
func (s *stubStore) ReadAssignee([]byte) ([]byte, error) {
51+
return nil, nil
52+
}
53+
54+
func (s *stubStore) CheckLimit([]byte, time.Duration, uint32, bool) (int, error) {
55+
return 0, nil
56+
}
57+
58+
func (s *stubStore) CheckEphemeralNonce([]byte, []byte, uint64, time.Duration) (bool, error) {
59+
return false, nil
60+
}
61+
62+
func (s *stubStore) RotateEphemeralNonce([]byte, []byte, uint64) error {
63+
return nil
64+
}
65+
66+
func (s *stubStore) WriteSignRequest([]byte, []byte) (time.Time, int, error) {
67+
return time.Time{}, 0, nil
68+
}
69+
70+
func (s *stubStore) Watch(key []byte) ([]byte, time.Time, int, error) {
71+
if s.watchFn != nil {
72+
return s.watchFn(key)
73+
}
74+
return nil, time.Time{}, 0, nil
75+
}
76+
77+
func testScalar() kyber.Scalar {
78+
return bn256.NewSuiteG2().Scalar().Pick(random.New())
79+
}
80+
81+
func testHandler(key kyber.Scalar, store *stubStore) *Handler {
82+
signers := []dkg.Node{
83+
{Index: 0, Public: crypto.PublicKey(key)},
84+
{Index: 1, Public: crypto.PublicKey(testScalar())},
85+
}
86+
poly := []kyber.Point{
87+
crypto.PublicKey(testScalar()),
88+
crypto.PublicKey(testScalar()),
89+
}
90+
return &Handler{
91+
store: store,
92+
conf: &Configuration{
93+
Key: key,
94+
Signers: signers,
95+
Poly: poly,
96+
Share: &share.PriShare{I: 1},
97+
Port: 7000,
98+
},
99+
render: render.New(),
100+
}
101+
}
102+
103+
func TestInfoSignsPayload(t *testing.T) {
104+
require := require.New(t)
105+
106+
key := testScalar()
107+
signers := []dkg.Node{
108+
{Index: 0, Public: crypto.PublicKey(key)},
109+
{Index: 1, Public: crypto.PublicKey(testScalar())},
110+
}
111+
poly := []kyber.Point{
112+
crypto.PublicKey(testScalar()),
113+
crypto.PublicKey(testScalar()),
114+
}
115+
116+
data, sigHex := info(key, signers, poly)
117+
body, ok := data.(map[string]any)
118+
require.True(ok)
119+
require.Equal(crypto.PublicKeyString(crypto.PublicKey(key)), body["identity"])
120+
require.Len(body["signers"], len(signers))
121+
require.Len(body["commitments"], len(poly))
122+
123+
rawSig, err := hex.DecodeString(sigHex)
124+
require.NoError(err)
125+
126+
payload, err := json.Marshal(data)
127+
require.NoError(err)
128+
require.NoError(crypto.Verify(crypto.PublicKey(key), payload, rawSig))
129+
}
130+
131+
func TestWatchRejectsInvalidWatcher(t *testing.T) {
132+
require := require.New(t)
133+
134+
_, _, err := watch(&stubStore{}, "bad-watcher")
135+
require.Error(err)
136+
require.Contains(err.Error(), "invalid watcher")
137+
}
138+
139+
func TestWatchReturnsStoreValues(t *testing.T) {
140+
require := require.New(t)
141+
142+
wantWatcher := bytes.Repeat([]byte{0x7f}, 32)
143+
wantAssignor := []byte("assignor")
144+
wantGenesis := time.Unix(1700000000, 123)
145+
store := &stubStore{
146+
watchFn: func(key []byte) ([]byte, time.Time, int, error) {
147+
require.Equal(wantWatcher, key)
148+
return wantAssignor, wantGenesis, 3, nil
149+
},
150+
}
151+
152+
genesis, counter, err := watch(store, hex.EncodeToString(wantWatcher))
153+
require.NoError(err)
154+
require.True(wantGenesis.Equal(genesis))
155+
require.Equal(3, counter)
156+
}
157+
158+
func TestServeHTTPGetRoot(t *testing.T) {
159+
require := require.New(t)
160+
161+
key := testScalar()
162+
hdr := testHandler(key, &stubStore{})
163+
164+
req := httptest.NewRequest(http.MethodGet, "/", nil)
165+
rec := httptest.NewRecorder()
166+
hdr.ServeHTTP(rec, req)
167+
168+
require.Equal(http.StatusOK, rec.Code)
169+
170+
var body struct {
171+
Data map[string]any `json:"data"`
172+
Signature string `json:"signature"`
173+
Version string `json:"version"`
174+
}
175+
require.NoError(json.Unmarshal(rec.Body.Bytes(), &body))
176+
require.Equal("v0.2.0", body.Version)
177+
require.Equal(crypto.PublicKeyString(crypto.PublicKey(key)), body.Data["identity"])
178+
179+
rawSig, err := hex.DecodeString(body.Signature)
180+
require.NoError(err)
181+
payload, err := json.Marshal(body.Data)
182+
require.NoError(err)
183+
require.NoError(crypto.Verify(crypto.PublicKey(key), payload, rawSig))
184+
}
185+
186+
func TestServeHTTPHandlesErrorsAndWatchRequests(t *testing.T) {
187+
require := require.New(t)
188+
189+
watcher := bytes.Repeat([]byte{0x42}, 32)
190+
genesis := time.Unix(1701000000, 0)
191+
store := &stubStore{
192+
watchFn: func(key []byte) ([]byte, time.Time, int, error) {
193+
require.Equal(watcher, key)
194+
return []byte("assignor"), genesis, 9, nil
195+
},
196+
}
197+
hdr := testHandler(testScalar(), store)
198+
199+
notFoundReq := httptest.NewRequest(http.MethodGet, "/missing", nil)
200+
notFoundRec := httptest.NewRecorder()
201+
hdr.ServeHTTP(notFoundRec, notFoundReq)
202+
require.Equal(http.StatusNotFound, notFoundRec.Code)
203+
204+
invalidJSONReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("{"))
205+
invalidJSONRec := httptest.NewRecorder()
206+
hdr.ServeHTTP(invalidJSONRec, invalidJSONReq)
207+
require.Equal(http.StatusBadRequest, invalidJSONRec.Code)
208+
209+
watchReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`{"action":"WATCH","watcher":"`+hex.EncodeToString(watcher)+`"}`))
210+
watchRec := httptest.NewRecorder()
211+
hdr.ServeHTTP(watchRec, watchReq)
212+
require.Equal(http.StatusOK, watchRec.Code)
213+
214+
var watchBody struct {
215+
Genesis time.Time `json:"genesis"`
216+
Counter int `json:"counter"`
217+
}
218+
require.NoError(json.Unmarshal(watchRec.Body.Bytes(), &watchBody))
219+
require.True(genesis.Equal(watchBody.Genesis))
220+
require.Equal(9, watchBody.Counter)
221+
222+
invalidActionReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`{"action":"NOPE"}`))
223+
invalidActionRec := httptest.NewRecorder()
224+
hdr.ServeHTTP(invalidActionRec, invalidActionReq)
225+
require.Equal(http.StatusBadRequest, invalidActionRec.Code)
226+
}
227+
228+
func TestHandleCORSPassesThroughAndHandlesOptions(t *testing.T) {
229+
require := require.New(t)
230+
231+
called := false
232+
handler := handleCORS(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
233+
called = true
234+
w.WriteHeader(http.StatusCreated)
235+
}))
236+
237+
optionsReq := httptest.NewRequest(http.MethodOptions, "/", nil)
238+
optionsReq.Header.Set("Origin", "https://example.com")
239+
optionsRec := httptest.NewRecorder()
240+
handler.ServeHTTP(optionsRec, optionsReq)
241+
242+
require.False(called)
243+
require.Equal(http.StatusOK, optionsRec.Code)
244+
require.Equal("https://example.com", optionsRec.Header().Get("Access-Control-Allow-Origin"))
245+
require.Equal("Content-Type,X-Request-ID", optionsRec.Header().Get("Access-Control-Allow-Headers"))
246+
247+
getReq := httptest.NewRequest(http.MethodGet, "/", nil)
248+
getReq.Header.Set("Origin", "https://example.com")
249+
getRec := httptest.NewRecorder()
250+
handler.ServeHTTP(getRec, getReq)
251+
252+
require.True(called)
253+
require.Equal(http.StatusCreated, getRec.Code)
254+
require.Equal("https://example.com", getRec.Header().Get("Access-Control-Allow-Origin"))
255+
}

0 commit comments

Comments
 (0)