@@ -224,23 +224,29 @@ static String getBufOffset(final Token token)
224224 return "offset" ;
225225 }
226226
227- static List <String > getOptionalPrimitiveFields (final List <Token > tokens )
227+ static List <String > getNullifyOptionalFields (final List <Token > tokens )
228228 {
229- final List <String > optionalPrimitiveFields = new ArrayList <>();
229+ final List <String > optionalFields = new ArrayList <>();
230230 Generators .forEachField (
231231 tokens ,
232232 (fieldToken , typeToken ) ->
233233 {
234- final String name = fieldToken .name ();
235- if (Objects .requireNonNull (typeToken .signal ()) == Signal .ENCODING )
234+ final String fieldName = formatFunctionName (fieldToken .name ());
235+ final Signal signal = Objects .requireNonNull (typeToken .signal ());
236+ if (signal == Signal .ENCODING )
236237 {
237238 if (isOptionalPrimitiveScalar (typeToken ))
238239 {
239- optionalPrimitiveFields .add (formatFunctionName ( name ) );
240+ optionalFields .add (fieldName );
240241 }
241242 }
243+ else if (signal == BEGIN_ENUM &&
244+ isOptionalEnum (typeToken , fieldToken ))
245+ {
246+ optionalFields .add (fieldName );
247+ }
242248 });
243- return optionalPrimitiveFields ;
249+ return optionalFields ;
244250 }
245251
246252 static void generateEncoderFields (
@@ -266,6 +272,10 @@ static void generateEncoderFields(
266272 break ;
267273 case BEGIN_ENUM :
268274 generateEnumEncoder (sb , level , fieldToken , typeToken , name );
275+ if (isOptionalEnum (typeToken , fieldToken ))
276+ {
277+ generateOptionalEnumEncoder (sb , level , typeToken , name );
278+ }
269279 break ;
270280 case BEGIN_SET :
271281 generateBitSetEncoder (sb , level , typeToken , name );
@@ -287,12 +297,12 @@ static void generateEncoderFields(
287297 static void generateNullifyOptionalFieldsMethod (
288298 final Appendable out ,
289299 final int level ,
290- final List <String > optionalPrimitiveFields ) throws IOException
300+ final List <String > optionalFields ) throws IOException
291301 {
292- indent (out , level , "/// Set all optional primitive scalar fields to their null values.\n " );
302+ indent (out , level , "/// Set all optional fields to their null values.\n " );
293303 indent (out , level , "#[inline]\n " );
294304 indent (out , level , "pub fn nullify_optional_fields(&mut self) -> &mut Self {\n " );
295- for (final String field : optionalPrimitiveFields )
305+ for (final String field : optionalFields )
296306 {
297307 indent (out , level + 1 , "self.%s_opt(None);\n " , field );
298308 }
@@ -305,6 +315,18 @@ private static boolean isOptionalPrimitiveScalar(final Token typeToken)
305315 return typeToken .arrayLength () == 1 && typeToken .encoding ().presence () == Encoding .Presence .OPTIONAL ;
306316 }
307317
318+ private static boolean isOptionalEnum (final Token typeToken , final Token fieldToken )
319+ {
320+ return typeToken .signal () == BEGIN_ENUM && fieldToken .isOptionalEncoding ();
321+ }
322+
323+ private static String rustEnumType (final Token typeToken )
324+ {
325+ final String referencedName = typeToken .referencedName ();
326+ final String enumTypeName = referencedName == null ? typeToken .name () : referencedName ;
327+ return format ("%s::%s" , toLowerSnakeCase (enumTypeName ), formatStructName (enumTypeName ));
328+ }
329+
308330 static void generateEncoderGroups (
309331 final StringBuilder sb ,
310332 final List <Token > tokens ,
@@ -642,10 +664,7 @@ private static void generateEnumEncoder(
642664 final Token typeToken ,
643665 final String name ) throws IOException
644666 {
645- final String referencedName = typeToken .referencedName ();
646- final String enumType = format ("%s::%s" ,
647- toLowerSnakeCase (referencedName == null ? typeToken .name () : referencedName ),
648- formatStructName (referencedName == null ? typeToken .name () : referencedName ));
667+ final String enumType = rustEnumType (typeToken );
649668
650669 if (fieldToken .isConstantEncoding ())
651670 {
@@ -667,6 +686,27 @@ private static void generateEnumEncoder(
667686 }
668687 }
669688
689+ private static void generateOptionalEnumEncoder (
690+ final StringBuilder sb ,
691+ final int level ,
692+ final Token typeToken ,
693+ final String name ) throws IOException
694+ {
695+ final String functionName = formatFunctionName (name );
696+ final String enumType = rustEnumType (typeToken );
697+
698+ indent (sb , level , "/// optional enum field '%s'\n " , name );
699+ generateRustDoc (sb , level , typeToken , typeToken .encoding ());
700+ indent (sb , level , "/// Set to `None` to encode the field null value.\n " );
701+ indent (sb , level , "#[inline]\n " );
702+ indent (sb , level , "pub fn %s_opt(&mut self, value: Option<%s>) {\n " , functionName , enumType );
703+ indent (sb , level + 1 , "match value {\n " );
704+ indent (sb , level + 2 , "Some(value) => self.%s(value),\n " , functionName );
705+ indent (sb , level + 2 , "None => self.%s(%s::NullVal),\n " , functionName , enumType );
706+ indent (sb , level + 1 , "}\n " );
707+ indent (sb , level , "}\n \n " );
708+ }
709+
670710 private static void generateBitSetEncoder (
671711 final StringBuilder sb ,
672712 final int level ,
@@ -1341,7 +1381,7 @@ private static void generateSingleBitSet(
13411381 static void appendImplEncoderTrait (
13421382 final Appendable out ,
13431383 final String typeName ,
1344- final List <String > optionalPrimitiveFields ) throws IOException
1384+ final List <String > optionalFields ) throws IOException
13451385 {
13461386 indent (out , 1 , "impl<%s> %s for %s {\n " , BUF_LIFETIME , withBufLifetime ("Writer" ), withBufLifetime (typeName ));
13471387 indent (out , 2 , "#[inline]\n " );
@@ -1361,13 +1401,13 @@ static void appendImplEncoderTrait(
13611401 indent (out , 3 , "self.limit = limit;\n " );
13621402 indent (out , 2 , "}\n " );
13631403
1364- if (!optionalPrimitiveFields .isEmpty ())
1404+ if (!optionalFields .isEmpty ())
13651405 {
13661406 indent (out , 0 , "\n " );
1367- indent (out , 2 , "/// Set all optional primitive scalar fields to their 'null' values.\n " );
1407+ indent (out , 2 , "/// Set all optional fields to their 'null' values.\n " );
13681408 indent (out , 2 , "#[inline]\n " );
13691409 indent (out , 2 , "fn nullify_optional_fields(&mut self) -> &mut Self {\n " );
1370- for (final String field : optionalPrimitiveFields )
1410+ for (final String field : optionalFields )
13711411 {
13721412 indent (out , 3 , "self.%s_opt(None);\n " , field );
13731413 }
@@ -1646,7 +1686,7 @@ static void appendImplEncoderForComposite(
16461686 final Appendable out ,
16471687 final int level ,
16481688 final String name ,
1649- final List <String > optionalPrimitiveFields ) throws IOException
1689+ final List <String > optionalFields ) throws IOException
16501690 {
16511691 appendImplWriterForComposite (out , level , name );
16521692
@@ -1662,13 +1702,13 @@ static void appendImplEncoderForComposite(
16621702 indent (out , level + 2 , "self.parent.as_mut().expect(\" parent missing\" ).set_limit(limit);\n " );
16631703 indent (out , level + 1 , "}\n " );
16641704
1665- if (!optionalPrimitiveFields .isEmpty ())
1705+ if (!optionalFields .isEmpty ())
16661706 {
16671707 indent (out , 0 , "\n " );
1668- indent (out , 2 , "/// Set all optional primitive scalar fields to their 'null' values.\n " );
1708+ indent (out , 2 , "/// Set all optional fields to their 'null' values.\n " );
16691709 indent (out , 2 , "#[inline]\n " );
16701710 indent (out , 2 , "fn nullify_optional_fields(&mut self) -> &mut Self {\n " );
1671- for (final String field : optionalPrimitiveFields )
1711+ for (final String field : optionalFields )
16721712 {
16731713 indent (out , 3 , "self.%s_opt(None);\n " , field );
16741714 }
@@ -1760,7 +1800,7 @@ private static void generateCompositeEncoder(
17601800 indent (out , 3 , "self.parent.take().ok_or(SbeErr::ParentNotSet)\n " );
17611801 indent (out , 2 , "}\n \n " );
17621802
1763- final List <String > optionalPrimitiveFields = new ArrayList <>();
1803+ final List <String > optionalFields = new ArrayList <>();
17641804 for (int i = 1 , end = tokens .size () - 1 ; i < end ; )
17651805 {
17661806 final Token encodingToken = tokens .get (i );
@@ -1774,11 +1814,16 @@ private static void generateCompositeEncoder(
17741814 {
17751815 generateOptionalPrimitiveEncoder (sb , 2 , encodingToken , encodingToken .name (),
17761816 encodingToken .encoding ());
1777- optionalPrimitiveFields .add (formatFunctionName (encodingToken .name ()));
1817+ optionalFields .add (formatFunctionName (encodingToken .name ()));
17781818 }
17791819 break ;
17801820 case BEGIN_ENUM :
17811821 generateEnumEncoder (sb , 2 , encodingToken , encodingToken , encodingToken .name ());
1822+ if (isOptionalEnum (encodingToken , encodingToken ))
1823+ {
1824+ generateOptionalEnumEncoder (sb , 2 , encodingToken , encodingToken .name ());
1825+ optionalFields .add (formatFunctionName (encodingToken .name ()));
1826+ }
17821827 break ;
17831828 case BEGIN_SET :
17841829 generateBitSetEncoder (sb , 2 , encodingToken , encodingToken .name ());
@@ -1794,7 +1839,7 @@ private static void generateCompositeEncoder(
17941839 i += encodingToken .componentTokenCount ();
17951840 }
17961841
1797- generateNullifyOptionalFieldsMethod (out , 2 , optionalPrimitiveFields );
1842+ generateNullifyOptionalFieldsMethod (out , 2 , optionalFields );
17981843 indent (out , 1 , "}\n " ); // end impl
17991844 indent (out , 0 , "} // end encoder mod \n " );
18001845 }
0 commit comments