Skip to content

Commit 50fab5b

Browse files
author
Leonardo Arcari
committed
feat(rust): generate setters for optional enum fields too
1 parent 8d16682 commit 50fab5b

7 files changed

Lines changed: 184 additions & 37 deletions

File tree

build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,7 @@ tasks.register('generateRustTestCodecs', JavaExec) {
679679
'sbe-tool/src/test/resources/issue1028.xml',
680680
'sbe-tool/src/test/resources/issue1057.xml',
681681
'sbe-tool/src/test/resources/issue1066.xml',
682+
'sbe-tool/src/test/resources/optional_enum_nullify.xml',
682683
'sbe-tool/src/test/resources/fixed-sized-primitive-array-types.xml',
683684
'sbe-tool/src/test/resources/example-bigendian-test-schema.xml',
684685
'sbe-tool/src/test/resources/nested-composite-name.xml',

rust/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ issue_987 = { path = "../generated/rust/issue987" }
1919
issue_1028 = { path = "../generated/rust/issue1028" }
2020
issue_1057 = { path = "../generated/rust/issue1057" }
2121
issue_1066 = { path = "../generated/rust/issue1066" }
22+
optional_enum_nullify = { path = "../generated/rust/optional_enum_nullify" }
2223
baseline_bigendian = { path = "../generated/rust/baseline-bigendian" }
2324
nested_composite_name = { path = "../generated/rust/nested-composite-name" }
2425
fixed_sized_primitive_array = { path = "../generated/rust/fixed_sized_primitive_array" }

rust/tests/optional_field_nullify_test.rs

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
use examples_extension as ext;
22
use examples_uk_co_real_logic_sbe_benchmarks_fix as mqfix;
3-
use issue_895 as i895;
4-
use issue_972 as i972;
53
use ext::{
64
boolean_type::BooleanType as ExtBooleanType,
75
boost_type::BoostType as ExtBoostType,
@@ -12,9 +10,7 @@ use ext::{
1210
message_header_codec::{self as ext_header, MessageHeaderDecoder as ExtHeaderDecoder},
1311
model::Model as ExtModel,
1412
optional_extras::OptionalExtras as ExtOptionalExtras,
15-
ReadBuf as ExtReadBuf,
16-
SbeResult as ExtSbeResult,
17-
WriteBuf as ExtWriteBuf,
13+
ReadBuf as ExtReadBuf, SbeResult as ExtSbeResult, WriteBuf as ExtWriteBuf,
1814
};
1915
use i895::{
2016
issue_895_codec::{Issue895Decoder, Issue895Encoder},
@@ -26,6 +22,8 @@ use i972::{
2622
message_header_codec::{self as i972_header, MessageHeaderDecoder as I972HeaderDecoder},
2723
Either as I972Either, ReadBuf as I972ReadBuf, WriteBuf as I972WriteBuf,
2824
};
25+
use issue_895 as i895;
26+
use issue_972 as i972;
2927
use mqfix::{
3028
mass_quote_codec::{
3129
encoder::{QuoteEntriesEncoder, QuoteSetsEncoder},
@@ -34,15 +32,23 @@ use mqfix::{
3432
message_header_codec::{self as mqfix_header, MessageHeaderDecoder as MassQuoteHeaderDecoder},
3533
ReadBuf as MassQuoteReadBuf, WriteBuf as MassQuoteWriteBuf,
3634
};
35+
use oen::{
36+
enum_type::EnumType as OenEnumType,
37+
message_header_codec::{
38+
self as oen_header, MessageHeaderDecoder as OptionalEnumNullifyHeaderDecoder,
39+
},
40+
optional_encoding_enum_type::OptionalEncodingEnumType as OenOptionalEncodingEnumType,
41+
optional_enum_nullify_codec::{OptionalEnumNullifyDecoder, OptionalEnumNullifyEncoder},
42+
ReadBuf as OenReadBuf, WriteBuf as OenWriteBuf,
43+
};
44+
use optional_enum_nullify as oen;
3745

3846
type TestResult = Result<(), Box<dyn std::error::Error>>;
3947

4048
macro_rules! create_encoder_with_header_parent {
4149
($buffer:expr, $encoder_ty:ty, $write_buf_ty:ty, $encoded_len:expr) => {{
42-
let encoder = <$encoder_ty>::default().wrap(
43-
<$write_buf_ty>::new($buffer.as_mut_slice()),
44-
$encoded_len,
45-
);
50+
let encoder = <$encoder_ty>::default()
51+
.wrap(<$write_buf_ty>::new($buffer.as_mut_slice()), $encoded_len);
4652
let mut header = encoder.header(0);
4753
header.parent().unwrap()
4854
}};
@@ -66,6 +72,15 @@ fn create_issue_972_encoder(buffer: &mut Vec<u8>) -> Issue972Encoder<'_> {
6672
)
6773
}
6874

75+
fn create_optional_enum_nullify_encoder(buffer: &mut Vec<u8>) -> OptionalEnumNullifyEncoder<'_> {
76+
create_encoder_with_header_parent!(
77+
buffer,
78+
OptionalEnumNullifyEncoder<'_>,
79+
OenWriteBuf<'_>,
80+
oen_header::ENCODED_LENGTH
81+
)
82+
}
83+
6984
fn create_mass_quote_encoder(buffer: &mut Vec<u8>) -> MassQuoteEncoder<'_> {
7085
create_encoder_with_header_parent!(
7186
buffer,
@@ -87,6 +102,12 @@ fn decode_issue_972(buffer: &[u8]) -> Issue972Decoder<'_> {
87102
Issue972Decoder::default().header(header, 0)
88103
}
89104

105+
fn decode_optional_enum_nullify(buffer: &[u8]) -> OptionalEnumNullifyDecoder<'_> {
106+
let buf = OenReadBuf::new(buffer);
107+
let header = OptionalEnumNullifyHeaderDecoder::default().wrap(buf, 0);
108+
OptionalEnumNullifyDecoder::default().header(header, 0)
109+
}
110+
90111
fn encode_extension_car_for_optional_field_test(
91112
cup_holder_count: Option<u8>,
92113
nullify: bool,
@@ -356,3 +377,51 @@ fn group_nullify_optional_fields_sets_optional_fields_to_none() -> TestResult {
356377

357378
Ok(())
358379
}
380+
381+
#[test]
382+
fn nullify_optional_fields_sets_optional_enum_field_to_null_value() -> TestResult {
383+
use oen::Encoder;
384+
let mut control_buffer = vec![0u8; 256];
385+
let mut control_encoder = create_optional_enum_nullify_encoder(&mut control_buffer);
386+
// Optionality is declared on the field, matching IR optional enum semantics.
387+
control_encoder.optional_enum_opt(Some(OenEnumType::One));
388+
// This field is required, even if its enum type encoding is optional.
389+
control_encoder.required_enum_from_optional_type(OenOptionalEncodingEnumType::Alpha);
390+
391+
let control_decoder = decode_optional_enum_nullify(control_buffer.as_slice());
392+
assert_eq!(OenEnumType::One, control_decoder.optional_enum());
393+
assert_eq!(
394+
OenOptionalEncodingEnumType::Alpha,
395+
control_decoder.required_enum_from_optional_type()
396+
);
397+
398+
let mut none_buffer = vec![0u8; 256];
399+
let mut none_encoder = create_optional_enum_nullify_encoder(&mut none_buffer);
400+
none_encoder.optional_enum_opt(None);
401+
none_encoder.required_enum_from_optional_type(OenOptionalEncodingEnumType::Beta);
402+
403+
let none_decoder = decode_optional_enum_nullify(none_buffer.as_slice());
404+
assert_eq!(OenEnumType::NullVal, none_decoder.optional_enum());
405+
assert_eq!(
406+
OenOptionalEncodingEnumType::Beta,
407+
none_decoder.required_enum_from_optional_type()
408+
);
409+
410+
let mut nullified_buffer = vec![0u8; 256];
411+
let mut nullified_encoder = create_optional_enum_nullify_encoder(&mut nullified_buffer);
412+
nullified_encoder.optional_enum_opt(Some(OenEnumType::Two));
413+
nullified_encoder.required_enum_from_optional_type(OenOptionalEncodingEnumType::Beta);
414+
// This is the behavior under test: nullify should route through *_opt(None)
415+
// only for field-optional members, so `optional_enum` becomes NullVal while
416+
// the required field remains unchanged even if its enum type encoding is optional.
417+
nullified_encoder.nullify_optional_fields();
418+
419+
let nullified_decoder = decode_optional_enum_nullify(nullified_buffer.as_slice());
420+
assert_eq!(OenEnumType::NullVal, nullified_decoder.optional_enum());
421+
assert_eq!(
422+
OenOptionalEncodingEnumType::Beta,
423+
nullified_decoder.required_enum_from_optional_type()
424+
);
425+
426+
Ok(())
427+
}

sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/MessageCoderDef.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ void generate(
7272

7373
if (codecType == Encoder)
7474
{
75-
final List<String> optionalPrimitiveFields = RustGenerator.getOptionalPrimitiveFields(fields);
76-
RustGenerator.appendImplEncoderTrait(sb, msgTypeName, optionalPrimitiveFields);
75+
final List<String> optionalFields = RustGenerator.getNullifyOptionalFields(fields);
76+
RustGenerator.appendImplEncoderTrait(sb, msgTypeName, optionalFields);
7777
}
7878
else
7979
{

sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -224,23 +224,29 @@ static String getBufOffset(final Token token)
224224
return "offset";
225225
}
226226

227-
static List<String> getOptionalPrimitiveFields(final List<Token> tokens)
227+
static List<String> getNullifyOptionalFields(final List<Token> tokens)
228228
{
229-
final List<String> optionalPrimitiveFields = new ArrayList<>();
229+
final List<String> optionalFields = new ArrayList<>();
230230
Generators.forEachField(
231231
tokens,
232232
(fieldToken, typeToken) ->
233233
{
234-
final String name = fieldToken.name();
235-
if (Objects.requireNonNull(typeToken.signal()) == Signal.ENCODING)
234+
final String fieldName = formatFunctionName(fieldToken.name());
235+
final Signal signal = Objects.requireNonNull(typeToken.signal());
236+
if (signal == Signal.ENCODING)
236237
{
237238
if (isOptionalPrimitiveScalar(typeToken))
238239
{
239-
optionalPrimitiveFields.add(formatFunctionName(name));
240+
optionalFields.add(fieldName);
240241
}
241242
}
243+
else if (signal == BEGIN_ENUM &&
244+
isOptionalEnum(typeToken, fieldToken))
245+
{
246+
optionalFields.add(fieldName);
247+
}
242248
});
243-
return optionalPrimitiveFields;
249+
return optionalFields;
244250
}
245251

246252
static void generateEncoderFields(
@@ -266,6 +272,10 @@ static void generateEncoderFields(
266272
break;
267273
case BEGIN_ENUM:
268274
generateEnumEncoder(sb, level, fieldToken, typeToken, name);
275+
if (isOptionalEnum(typeToken, fieldToken))
276+
{
277+
generateOptionalEnumEncoder(sb, level, typeToken, name);
278+
}
269279
break;
270280
case BEGIN_SET:
271281
generateBitSetEncoder(sb, level, typeToken, name);
@@ -287,12 +297,12 @@ static void generateEncoderFields(
287297
static void generateNullifyOptionalFieldsMethod(
288298
final Appendable out,
289299
final int level,
290-
final List<String> optionalPrimitiveFields) throws IOException
300+
final List<String> optionalFields) throws IOException
291301
{
292-
indent(out, level, "/// Set all optional primitive scalar fields to their null values.\n");
302+
indent(out, level, "/// Set all optional fields to their null values.\n");
293303
indent(out, level, "#[inline]\n");
294304
indent(out, level, "pub fn nullify_optional_fields(&mut self) -> &mut Self {\n");
295-
for (final String field : optionalPrimitiveFields)
305+
for (final String field : optionalFields)
296306
{
297307
indent(out, level + 1, "self.%s_opt(None);\n", field);
298308
}
@@ -305,6 +315,18 @@ private static boolean isOptionalPrimitiveScalar(final Token typeToken)
305315
return typeToken.arrayLength() == 1 && typeToken.encoding().presence() == Encoding.Presence.OPTIONAL;
306316
}
307317

318+
private static boolean isOptionalEnum(final Token typeToken, final Token fieldToken)
319+
{
320+
return typeToken.signal() == BEGIN_ENUM && fieldToken.isOptionalEncoding();
321+
}
322+
323+
private static String rustEnumType(final Token typeToken)
324+
{
325+
final String referencedName = typeToken.referencedName();
326+
final String enumTypeName = referencedName == null ? typeToken.name() : referencedName;
327+
return format("%s::%s", toLowerSnakeCase(enumTypeName), formatStructName(enumTypeName));
328+
}
329+
308330
static void generateEncoderGroups(
309331
final StringBuilder sb,
310332
final List<Token> tokens,
@@ -642,10 +664,7 @@ private static void generateEnumEncoder(
642664
final Token typeToken,
643665
final String name) throws IOException
644666
{
645-
final String referencedName = typeToken.referencedName();
646-
final String enumType = format("%s::%s",
647-
toLowerSnakeCase(referencedName == null ? typeToken.name() : referencedName),
648-
formatStructName(referencedName == null ? typeToken.name() : referencedName));
667+
final String enumType = rustEnumType(typeToken);
649668

650669
if (fieldToken.isConstantEncoding())
651670
{
@@ -667,6 +686,27 @@ private static void generateEnumEncoder(
667686
}
668687
}
669688

689+
private static void generateOptionalEnumEncoder(
690+
final StringBuilder sb,
691+
final int level,
692+
final Token typeToken,
693+
final String name) throws IOException
694+
{
695+
final String functionName = formatFunctionName(name);
696+
final String enumType = rustEnumType(typeToken);
697+
698+
indent(sb, level, "/// optional enum field '%s'\n", name);
699+
generateRustDoc(sb, level, typeToken, typeToken.encoding());
700+
indent(sb, level, "/// Set to `None` to encode the field null value.\n");
701+
indent(sb, level, "#[inline]\n");
702+
indent(sb, level, "pub fn %s_opt(&mut self, value: Option<%s>) {\n", functionName, enumType);
703+
indent(sb, level + 1, "match value {\n");
704+
indent(sb, level + 2, "Some(value) => self.%s(value),\n", functionName);
705+
indent(sb, level + 2, "None => self.%s(%s::NullVal),\n", functionName, enumType);
706+
indent(sb, level + 1, "}\n");
707+
indent(sb, level, "}\n\n");
708+
}
709+
670710
private static void generateBitSetEncoder(
671711
final StringBuilder sb,
672712
final int level,
@@ -1341,7 +1381,7 @@ private static void generateSingleBitSet(
13411381
static void appendImplEncoderTrait(
13421382
final Appendable out,
13431383
final String typeName,
1344-
final List<String> optionalPrimitiveFields) throws IOException
1384+
final List<String> optionalFields) throws IOException
13451385
{
13461386
indent(out, 1, "impl<%s> %s for %s {\n", BUF_LIFETIME, withBufLifetime("Writer"), withBufLifetime(typeName));
13471387
indent(out, 2, "#[inline]\n");
@@ -1361,13 +1401,13 @@ static void appendImplEncoderTrait(
13611401
indent(out, 3, "self.limit = limit;\n");
13621402
indent(out, 2, "}\n");
13631403

1364-
if (!optionalPrimitiveFields.isEmpty())
1404+
if (!optionalFields.isEmpty())
13651405
{
13661406
indent(out, 0, "\n");
1367-
indent(out, 2, "/// Set all optional primitive scalar fields to their 'null' values.\n");
1407+
indent(out, 2, "/// Set all optional fields to their 'null' values.\n");
13681408
indent(out, 2, "#[inline]\n");
13691409
indent(out, 2, "fn nullify_optional_fields(&mut self) -> &mut Self {\n");
1370-
for (final String field : optionalPrimitiveFields)
1410+
for (final String field : optionalFields)
13711411
{
13721412
indent(out, 3, "self.%s_opt(None);\n", field);
13731413
}
@@ -1646,7 +1686,7 @@ static void appendImplEncoderForComposite(
16461686
final Appendable out,
16471687
final int level,
16481688
final String name,
1649-
final List<String> optionalPrimitiveFields) throws IOException
1689+
final List<String> optionalFields) throws IOException
16501690
{
16511691
appendImplWriterForComposite(out, level, name);
16521692

@@ -1662,13 +1702,13 @@ static void appendImplEncoderForComposite(
16621702
indent(out, level + 2, "self.parent.as_mut().expect(\"parent missing\").set_limit(limit);\n");
16631703
indent(out, level + 1, "}\n");
16641704

1665-
if (!optionalPrimitiveFields.isEmpty())
1705+
if (!optionalFields.isEmpty())
16661706
{
16671707
indent(out, 0, "\n");
1668-
indent(out, 2, "/// Set all optional primitive scalar fields to their 'null' values.\n");
1708+
indent(out, 2, "/// Set all optional fields to their 'null' values.\n");
16691709
indent(out, 2, "#[inline]\n");
16701710
indent(out, 2, "fn nullify_optional_fields(&mut self) -> &mut Self {\n");
1671-
for (final String field : optionalPrimitiveFields)
1711+
for (final String field : optionalFields)
16721712
{
16731713
indent(out, 3, "self.%s_opt(None);\n", field);
16741714
}
@@ -1760,7 +1800,7 @@ private static void generateCompositeEncoder(
17601800
indent(out, 3, "self.parent.take().ok_or(SbeErr::ParentNotSet)\n");
17611801
indent(out, 2, "}\n\n");
17621802

1763-
final List<String> optionalPrimitiveFields = new ArrayList<>();
1803+
final List<String> optionalFields = new ArrayList<>();
17641804
for (int i = 1, end = tokens.size() - 1; i < end; )
17651805
{
17661806
final Token encodingToken = tokens.get(i);
@@ -1774,11 +1814,16 @@ private static void generateCompositeEncoder(
17741814
{
17751815
generateOptionalPrimitiveEncoder(sb, 2, encodingToken, encodingToken.name(),
17761816
encodingToken.encoding());
1777-
optionalPrimitiveFields.add(formatFunctionName(encodingToken.name()));
1817+
optionalFields.add(formatFunctionName(encodingToken.name()));
17781818
}
17791819
break;
17801820
case BEGIN_ENUM:
17811821
generateEnumEncoder(sb, 2, encodingToken, encodingToken, encodingToken.name());
1822+
if (isOptionalEnum(encodingToken, encodingToken))
1823+
{
1824+
generateOptionalEnumEncoder(sb, 2, encodingToken, encodingToken.name());
1825+
optionalFields.add(formatFunctionName(encodingToken.name()));
1826+
}
17821827
break;
17831828
case BEGIN_SET:
17841829
generateBitSetEncoder(sb, 2, encodingToken, encodingToken.name());
@@ -1794,7 +1839,7 @@ private static void generateCompositeEncoder(
17941839
i += encodingToken.componentTokenCount();
17951840
}
17961841

1797-
generateNullifyOptionalFieldsMethod(out, 2, optionalPrimitiveFields);
1842+
generateNullifyOptionalFieldsMethod(out, 2, optionalFields);
17981843
indent(out, 1, "}\n"); // end impl
17991844
indent(out, 0, "} // end encoder mod \n");
18001845
}

sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/SubGroup.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ void generateEncoder(
7272
indent(sb, level, "initial_limit: usize,\n");
7373
indent(sb, level - 1, "}\n\n");
7474

75-
final List<String> optionalPrimitiveFields = RustGenerator.getOptionalPrimitiveFields(fields);
76-
RustGenerator.appendImplEncoderForComposite(sb, level - 1, name, optionalPrimitiveFields);
75+
final List<String> optionalFields = RustGenerator.getNullifyOptionalFields(fields);
76+
RustGenerator.appendImplEncoderForComposite(sb, level - 1, name, optionalFields);
7777

7878
// define impl...
7979
indent(sb, level - 1, "impl<'a, P> %s<P> where P: Encoder<'a> + Default {\n", name);

0 commit comments

Comments
 (0)