diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java index a3cabcc660ba2..985e3d7f76c13 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java @@ -247,6 +247,27 @@ public static DataType extractFromMethodReturnType( // Methods that extract a data type from a JVM Class with prior logical information // -------------------------------------------------------------------------------------------- + /** + * Extracts a data type from a reflective {@link Field}, honoring a {@link DataTypeHint} on the + * field if present and otherwise falling back to reflective extraction with default templates. + */ + public static DataType extractFromField(DataTypeFactory typeFactory, Field field) { + final DataTypeHint hint = field.getAnnotation(DataTypeHint.class); + final DataTypeTemplate template = + hint != null + ? DataTypeTemplate.fromDefaults() + .mergeWithInnerAnnotation(typeFactory, hint) + : DataTypeTemplate.fromDefaults(); + return extractDataTypeWithClassContext( + typeFactory, + template, + field.getDeclaringClass(), + field.getGenericType(), + String.format( + " in field '%s' of class '%s'", + field.getName(), field.getDeclaringClass().getName())); + } + public static DataType extractFromStructuredClass( DataTypeFactory typeFactory, Class implementationClass) { final DataType dataType = diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/TypeInfoDataTypeConverter.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/TypeInfoDataTypeConverter.java index 7d8163aba531e..4a46ae4c54762 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/TypeInfoDataTypeConverter.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/utils/TypeInfoDataTypeConverter.java @@ -33,10 +33,12 @@ import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.api.java.typeutils.TupleTypeInfoBase; +import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.DataTypeQueryable; +import org.apache.flink.table.types.extraction.DataTypeExtractor; import org.apache.flink.table.types.extraction.ExtractionUtils; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.RawType; @@ -54,6 +56,7 @@ import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; +import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; @@ -271,13 +274,37 @@ private static DataType convertToStructuredType( final String[] fieldNames = compositeType.getFieldNames(); final Class typeClass = compositeType.getTypeClass(); + final Map hintedPojoFields; + if (compositeType instanceof PojoTypeInfo) { + final PojoTypeInfo pojoTypeInfo = (PojoTypeInfo) compositeType; + hintedPojoFields = new HashMap<>(); + for (int pos = 0; pos < arity; pos++) { + final Field field = pojoTypeInfo.getPojoFieldAt(pos).getField(); + if (field.isAnnotationPresent(DataTypeHint.class)) { + hintedPojoFields.put(field.getName(), field); + } + } + } else { + hintedPojoFields = Collections.emptyMap(); + } + final Map fieldDataTypes = new LinkedHashMap<>(); IntStream.range(0, arity) .forEachOrdered( - pos -> + pos -> { + final String name = fieldNames[pos]; + final Field hintedField = hintedPojoFields.get(name); + if (hintedField != null) { + fieldDataTypes.put( + name, + DataTypeExtractor.extractFromField( + dataTypeFactory, hintedField)); + } else { fieldDataTypes.put( - fieldNames[pos], - toDataType(dataTypeFactory, compositeType.getTypeAt(pos)))); + name, + toDataType(dataTypeFactory, compositeType.getTypeAt(pos))); + } + }); final List fieldNamesReordered; final boolean isNullable; @@ -297,6 +324,10 @@ private static DataType convertToStructuredType( // therefore we need to check the reflective field for more details fieldDataTypes.replaceAll( (name, dataType) -> { + // hinted fields already resolved through DataTypeExtractor + if (hintedPojoFields.containsKey(name)) { + return dataType; + } final Class fieldClass = pojoFields.stream() .filter(f -> f.getName().equals(name)) diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/TypeInfoDataTypeConverterTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/TypeInfoDataTypeConverterTest.java index c22ae68b3f173..ab4e8eb039a53 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/TypeInfoDataTypeConverterTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/TypeInfoDataTypeConverterTest.java @@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.types.utils.DataTypeFactoryMock; import org.apache.flink.table.types.utils.TypeInfoDataTypeConverter; @@ -32,6 +33,7 @@ import org.junit.jupiter.params.provider.MethodSource; import java.time.DayOfWeek; +import java.util.HashMap; import java.util.Optional; import java.util.stream.Stream; @@ -111,7 +113,17 @@ private static Stream testData() { .lookupExpects(DayOfWeek.class) .expectDataType(dummyRaw(DayOfWeek.class)), TestSpec.forType(Types.VARIANT).expectDataType(DataTypes.VARIANT()), - TestSpec.forType(Types.BITMAP).expectDataType(DataTypes.BITMAP())); + TestSpec.forType(Types.BITMAP).expectDataType(DataTypes.BITMAP()), + TestSpec.forType(Types.POJO(PojoWithHintedMapField.class)) + .rawFallback(HashMap.class) + .expectDataType( + DataTypes.STRUCTURED( + PojoWithHintedMapField.class, + DataTypes.FIELD( + "headers", + DataTypes.MAP(DataTypes.STRING(), DataTypes.BYTES()) + .bridgedTo(HashMap.class)), + DataTypes.FIELD("id", DataTypes.STRING())))); } @ParameterizedTest(name = "{index}: {0}") @@ -148,6 +160,11 @@ TestSpec lookupExpects(Class lookupClass) { return this; } + TestSpec rawFallback(Class rawClass) { + typeFactory.dataType = Optional.of(dummyRaw(rawClass)); + return this; + } + TestSpec expectDataType(DataType expectedDataType) { this.expectedDataType = expectedDataType; return this; @@ -183,6 +200,19 @@ public PojoWithFieldOrder(String name, boolean gender, int age) { } } + /** POJO with a field carrying a {@link DataTypeHint} that overrides reflective extraction. */ + public static class PojoWithHintedMapField { + + public String id; + + @DataTypeHint("MAP") + public HashMap headers; + + public PojoWithHintedMapField() { + // default constructor + } + } + /** POJO that defines a field order via an additional constructor. */ @SuppressWarnings("unused") public static class PojoWithDefaultFieldOrder {