1111#include <pb_encode.h>
1212#include <squareup/subzero/internal.pb.h>
1313
14- static size_t get_serialized_message_size (const InternalCommandRequest * const cmd ) {
14+ // Helper which returns the size of a buffer that would be needed to hold the serialized version of the given
15+ // protobuf structure, assuming that pb_encode_delimited() serialization will be used.
16+ static size_t get_serialized_proto_struct_size (const pb_field_t fields [], const void * const proto_struct ) {
1517 pb_ostream_t stream = PB_OSTREAM_SIZING ;
16- if (!pb_encode_delimited (& stream , InternalCommandRequest_fields , cmd )) {
18+ if (!pb_encode_delimited (& stream , fields , proto_struct )) {
1719 ERROR ("%s: pb_encode_delimited() failed: %s" , __func__ , PB_GET_ERROR (& stream ));
1820 return 0 ;
1921 }
2022 return stream .bytes_written ;
2123}
2224
25+ // Serializes the given protobuf structure to the given buffer of the given size, using pb_encode_delimited().
26+ // Returns true on success or false on failure.
27+ // Caller brings their own buffer memory, this function does not allocate.
28+ static bool serialize_to_buffer (
29+ void * const buffer ,
30+ const size_t buffer_size ,
31+ const pb_field_t fields [],
32+ const void * const proto_struct ) {
33+ pb_ostream_t ostream = pb_ostream_from_buffer (buffer , buffer_size );
34+ if (!pb_encode_delimited (& ostream , fields , proto_struct )) {
35+ ERROR ("%s: pb_encode_delimited() failed: %s" , __func__ , PB_GET_ERROR (& ostream ));
36+ return false;
37+ }
38+ return true;
39+ }
40+
41+ // Deserializes the given buffer into the given protobuf structure, using pb_decode_delimited().
42+ // Returns true on success or false on failure.
43+ // Caller brings their own protobuf structure, this function does not allocate.
44+ static bool deserialize_from_buffer (
45+ const void * const buffer ,
46+ const size_t buffer_size ,
47+ const pb_field_t fields [],
48+ void * const proto_struct ) {
49+ pb_istream_t istream = pb_istream_from_buffer (buffer , buffer_size );
50+ if (!pb_decode_delimited (& istream , fields , proto_struct )) {
51+ ERROR ("%s: pb_decode_delimited() failed: %s" , __func__ , PB_GET_ERROR (& istream ));
52+ return false;
53+ }
54+ return true;
55+ }
56+
2357int verify_rpc_oversized_message_rejected (void ) {
2458 int result = 0 ;
2559 uint8_t * serialized_request = NULL ;
2660 uint8_t * serialized_response = NULL ;
2761
62+ // Construct an initial InternalCommandRequest which holds an InitWallet command
63+ // with a maximum-allowed-length random_bytes field.
2864 InternalCommandRequest cmd = InternalCommandRequest_init_default ;
2965 cmd .version = VERSION ;
3066 cmd .wallet_id = 1 ; // dummy value
@@ -35,12 +71,15 @@ int verify_rpc_oversized_message_rejected(void) {
3571 cmd .command .InitWallet .random_bytes .size = MASTER_SEED_SIZE ;
3672 random_buffer (cmd .command .InitWallet .random_bytes .bytes , MASTER_SEED_SIZE );
3773
38- size_t serialized_size = get_serialized_message_size (& cmd );
74+ // Compute the size of the serialized struct.
75+ size_t serialized_size = get_serialized_proto_struct_size (InternalCommandRequest_fields , & cmd );
3976 if (serialized_size == 0 ) {
77+ ERROR ("%s: error computing serialized request size" , __func__ );
4078 result = -1 ;
4179 goto out ;
4280 }
4381
82+ // Allocate a buffer to hold the serialized struct.
4483 // Note that we allocate 1 extra byte because we'll be extending the message.
4584 serialized_request = (uint8_t * ) calloc (1 , serialized_size + 1 );
4685 if (NULL == serialized_request ) {
@@ -49,30 +88,30 @@ int verify_rpc_oversized_message_rejected(void) {
4988 goto out ;
5089 }
5190
52- pb_ostream_t ostream = pb_ostream_from_buffer ( serialized_request , serialized_size );
53- if (!pb_encode_delimited ( & ostream , InternalCommandRequest_fields , & cmd )) {
54- ERROR ("%s: pb_encode_delimited () failed: %s " , __func__ , PB_GET_ERROR ( & ostream ) );
91+ // Serialize the struct to a byte array.
92+ if (!serialize_to_buffer ( serialized_request , serialized_size , InternalCommandRequest_fields , & cmd )) {
93+ ERROR ("%s: serialize_to_buf () failed" , __func__ );
5594 result = -1 ;
5695 goto out ;
5796 }
5897
5998 // Helper macro used to check our assumptions in the gnarly protobuf mangling code below
60- #define ASSERT_BYTE_EQUALS (buf , idx , expected_val ) \
61- do { \
62- const uint8_t* buf_ = (buf); \
63- const size_t idx_ = (idx); \
64- const uint8_t expected_val_ = (expected_val); \
65- const uint8_t actual_val_ = buf_[idx_]; \
66- if (actual_val_ != expected_val_) { \
67- ERROR( \
68- "%s: buf[%zu] contains unexpected value: %hhu, expected: %hhu", \
69- __func__, \
70- idx_, \
71- actual_val_, \
72- expected_val_); \
73- result = -1; \
74- goto out; \
75- } \
99+ #define ASSERT_BYTE_EQUALS (buf , idx , expected_val ) \
100+ do { \
101+ const uint8_t* buf_ = (buf); \
102+ const size_t idx_ = (idx); \
103+ const uint8_t expected_val_ = (expected_val); \
104+ const uint8_t actual_val_ = buf_[idx_]; \
105+ if (actual_val_ != expected_val_) { \
106+ ERROR( \
107+ "%s: buf[%zu] contains an unexpected value: %hhu, expected: %hhu", \
108+ __func__, \
109+ idx_, \
110+ actual_val_, \
111+ expected_val_); \
112+ result = -1; \
113+ goto out; \
114+ } \
76115 } while (0)
77116
78117 // Corrupt the message by making the random_bytes field 1 byte longer than the max allowed size.
@@ -87,23 +126,23 @@ int verify_rpc_oversized_message_rejected(void) {
87126 // length will actually take more than 1 byte, shifting everything after
88127 // it by a byte.
89128 // *** NOTE: WE NEED TO INCREMENT THIS BY 1. ***
90- // serialized_request[1] - field id (1 << 3) + tag (0) for field 1 (version). Should equal 0x08 .
129+ // serialized_request[1] - field id (1 << 3) + tag (0) for field 1 (version). Should equal 8 .
91130 // serialized_request[2..3] - varint-encoded value for field 1. Leave this alone, it's the
92131 // contents of the 'version' field (210 at the time of writing). If
93132 // version ever exceeds 16383, this will start taking up an extra byte
94133 // and shift everything after it by a byte.
95- // serialized_request[4] - field id (2 << 3) + tag (0) for field 2 (wallet_id). Should equal 0x10 .
96- // serialized_request[5] - varint-encoded value for field 2. Leave this allone , it's the dummy
97- // 'wallet' field which we set to 1 above. Should equal 0x01 .
134+ // serialized_request[4] - field id (2 << 3) + tag (0) for field 2 (wallet_id). Should equal 16 .
135+ // serialized_request[5] - varint-encoded value for field 2. Leave this alone , it's the dummy
136+ // 'wallet' field which we set to 1 above. Should equal 1 .
98137 // serialized_request[6] - field id (5 << 3) + tag (2, for 'LEN') for field 5 (command.InitWallet).
99- // Should equal 0x2a .
138+ // Should equal 42 .
100139 // serialized_request[7] - varint-encoded LEN of the InitWalletRequest submessage.
101- // Should equal 0x42 (decimal 66) .
140+ // Should equal 66 .
102141 // *** NOTE: WE NEED TO INCREMENT THIS BY 1. ***
103142 // serialized_request[8] - field id (1 << 3) + tag (2, for 'LEN') for field 1 of sub-message.
104- // Should equal 0x0a .
143+ // Should equal 10 .
105144 // serialized_request[9] - varint-encoded LEN of field 1 (random_bytes) of sub-message.
106- // Should equal 0x40 (decimal 64) .
145+ // Should equal 64 .
107146 // *** NOTE: WE NEED TO INCREMENT THIS BY 1. ***
108147 // serialized_request[10..73] - the contents of the random_bytes field. Should be 64 bytes in length.
109148 // serialized_request[74] - doesn't exist in the original message. We add an extra data byte here.
@@ -130,17 +169,21 @@ int verify_rpc_oversized_message_rejected(void) {
130169 serialized_request [7 ]++ ; // increment LEN byte for top-level field 5
131170 serialized_request [9 ]++ ; // increment LEN byte for nested field 1
132171 serialized_request [serialized_size ] = 0xaa ; // set the last byte to an arbitrary value
172+ serialized_size ++ ; // increment serialized_size since we added a byte of data
133173
134- pb_istream_t istream = pb_istream_from_buffer ( serialized_request , serialized_size + 1 );
174+ // Allocate a buffer for the serialized response.
135175 const size_t response_buffer_size = 2048 ; // 2048 bytes should be more than enough
136176 serialized_response = (uint8_t * ) calloc (1 , response_buffer_size );
137177 if (NULL == serialized_response ) {
138178 ERROR ("%s: calloc(1, %zu) failed" , __func__ , response_buffer_size );
139179 result = -1 ;
140180 goto out ;
141181 }
142- pb_ostream_t ostream2 = pb_ostream_from_buffer (serialized_response , response_buffer_size );
143- ERROR ("(next line is expected to show red text...)" );
182+
183+ // Create a stream which will read from the corrupted serialized buffer.
184+ pb_istream_t istream = pb_istream_from_buffer (serialized_request , serialized_size );
185+ // Create a stream which will write to the response buffer.
186+ pb_ostream_t ostream = pb_ostream_from_buffer (serialized_response , response_buffer_size );
144187
145188 // Now that we have a serialized buffer, try to pass it to handle_incoming_message().
146189 // This should fail because the InitWallet.random_bytes field has a length of 65 bytes,
@@ -149,23 +192,26 @@ int verify_rpc_oversized_message_rejected(void) {
149192 // NOTE: when building for nCipher, there are command hooks that would reject the command
150193 // because it's missing the tickets for key use authorization. But this doesn't matter for
151194 // this test case, because the protobuf parsing happens before that and fails first.
152- handle_incoming_message (& istream , & ostream2 );
153- const size_t actual_response_size = ostream2 .bytes_written ;
195+ ERROR ("(next line is expected to show red text...)" );
196+
197+ handle_incoming_message (& istream , & ostream ); // <---- this is the actual function under test
198+
199+ // Extract the response structure from the serialized_response buffer. It should be an error.
200+ const size_t actual_response_size = ostream .bytes_written ;
154201 if (actual_response_size == 0 ) {
155- ERROR ("%s: no response received from handle_incoming_message(): %s" , __func__ , PB_GET_ERROR (& ostream2 ));
202+ ERROR ("%s: no response received from handle_incoming_message(): %s" , __func__ , PB_GET_ERROR (& ostream ));
156203 result = -1 ;
157204 goto out ;
158205 }
159- pb_istream_t istream2 = pb_istream_from_buffer (serialized_response , actual_response_size );
160- InternalCommandResponse response ; // note: no need to init, pb_decode_delimited() does it
161- if (!pb_decode_delimited (& istream2 , InternalCommandResponse_fields , & response )) {
162- ERROR (
163- "%s: pb_decode_delimited(..., InternalCommandResponse_fields, ...) failed: %s" ,
164- __func__ ,
165- PB_GET_ERROR (& istream2 ));
206+
207+ InternalCommandResponse response ; // note: no need to init, deserialize_from_buf() does it via pb_decode_delimited().
208+ if (!deserialize_from_buffer (serialized_response , actual_response_size , InternalCommandResponse_fields , & response )) {
209+ ERROR ("%s: deserialize_from_buf() failed" , __func__ );
166210 result = -1 ;
167211 goto out ;
168212 }
213+
214+ // Check that the response contains an error.
169215 if (response .which_response != InternalCommandResponse_Error_tag ) {
170216 ERROR (
171217 "%s: wrong response tag: %d, expected: %d" ,
@@ -175,6 +221,8 @@ int verify_rpc_oversized_message_rejected(void) {
175221 result = -1 ;
176222 goto out ;
177223 }
224+
225+ // Check that the error response contains the expected error code.
178226 if (response .response .Error .code != Result_COMMAND_DECODE_FAILED ) {
179227 ERROR (
180228 "%s: wrong response error code: %d, expected: %d" ,
@@ -184,11 +232,15 @@ int verify_rpc_oversized_message_rejected(void) {
184232 result = -1 ;
185233 goto out ;
186234 }
235+
236+ // Check that the error response contains some message.
187237 if (!response .response .Error .has_message ) {
188238 ERROR ("%s: error response does not contain a 'message' field" , __func__ );
189239 result = -1 ;
190240 goto out ;
191241 }
242+
243+ // Check that the error response contains the expected message.
192244 if (0 != strcmp ("Decode Input failed: bytes overflow" , response .response .Error .message )) {
193245 ERROR ("%s: error response contains unexpected message: %s" , __func__ , response .response .Error .message );
194246 result = -1 ;
@@ -198,5 +250,8 @@ int verify_rpc_oversized_message_rejected(void) {
198250out :
199251 free (serialized_request );
200252 free (serialized_response );
253+ if (result == 0 ) {
254+ INFO ("%s: ok" , __func__ );
255+ }
201256 return result ;
202257}
0 commit comments