diff --git a/.dockerignore b/.dockerignore index 50d68018327..400575544c0 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,6 +1,7 @@ **/*.log **/target !gremlin-server/target/apache-tinkerpop-gremlin-server-* +!gremlin-server/target/gremlin-server-*-tests.jar !gremlin-console/target/apache-tinkerpop-gremlin-console-* *.iml .idea diff --git a/CHANGELOG.asciidoc b/CHANGELOG.asciidoc index e312807286e..de5df857c03 100644 --- a/CHANGELOG.asciidoc +++ b/CHANGELOG.asciidoc @@ -27,6 +27,7 @@ image::https://raw.githubusercontent.com/apache/tinkerpop/master/docs/static/ima * Added typed numeric wrappers and `preciseNumbers` connection option to `gremlin-javascript` for explicit control over numeric type serialization and deserialization. * Added `NextN(n)` to `Traversal` in `gremlin-go` for batched result iteration, providing API parity with `next(n)` in the Java, Python, and .NET GLVs, and updated the Go translators in `gremlin-core` and `gremlin-javascript` to emit `NextN(n)` for the batched form. +* Added Provider Defined Types (PDT) support — graph providers can define custom types via `@ProviderDefined` annotation that serialize/deserialize seamlessly across all GLVs without driver-side configuration. Replaces TP3 custom type mechanism. * Added Gremlator, a single page web application, that translates Gremlin into various programming languages like Javascript and Python. * Removed `uuid` dependency from `gremlin-javascript` in favor of the built-in `globalThis.crypto.randomUUID()`. * Added streaming HTTP response support to `gremlin-driver` for incremental result deserialization over GraphBinary. diff --git a/docker/gremlin-test-server/Dockerfile b/docker/gremlin-test-server/Dockerfile index 8ac853c1d1a..3740499dcc7 100644 --- a/docker/gremlin-test-server/Dockerfile +++ b/docker/gremlin-test-server/Dockerfile @@ -24,6 +24,7 @@ USER root RUN mkdir -p /opt WORKDIR /opt COPY gremlin-server/src/test /opt/test/ +COPY gremlin-server/target/gremlin-server-*-tests.jar /opt/gremlin-server/lib/ COPY docker/gremlin-server/docker-entrypoint.sh docker/gremlin-server/*.yaml docker/gremlin-server/*.conf /opt/ RUN chmod 755 /opt/docker-entrypoint.sh diff --git a/docs/src/dev/provider/index.asciidoc b/docs/src/dev/provider/index.asciidoc index 7654807a90f..7de043bc9da 100644 --- a/docs/src/dev/provider/index.asciidoc +++ b/docs/src/dev/provider/index.asciidoc @@ -1334,6 +1334,213 @@ can be used as a reference on how these files can be used and its link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/structure/io/Model.java[model] shows the Java representation of those files. +[[provider-defined-types]] +=== Provider Defined Types (PDT) + +Provider Defined Types allow graph providers to expose custom types that drivers can serialize and deserialize without +manual configuration on the client side. A provider annotates a class (or registers an adapter for a class it doesn't +own), and the type flows through the wire protocol automatically. Clients receive PDT values as structured objects they +can use directly or hydrate into language-native types. + +==== Basic Usage + +Annotate a class with `@ProviderDefined` from the `org.apache.tinkerpop.gremlin.structure.io.pdt` package: + +[source,java] +---- +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefined; + +@ProviderDefined(name = "mygraph:Point") +public class Point { + public double x; + public double y; + + public Point(double x, double y) { + this.x = x; + this.y = y; + } +} +---- + +The `name` attribute is a unique identifier for the type. It is strongly recommended to namespace type names using +your graph's identifier as a prefix (e.g. `"mygraph:Point"`). This avoids collisions when clients interact with +multiple providers and makes the origin of a type immediately clear. By default, all fields are included. Use +`includedFields` or `excludedFields` to control which fields are serialized: + +[source,java] +---- +@ProviderDefined(name = "mygraph:Point", includedFields = {"x", "y"}) +public class Point { ... } + +// or exclude specific fields +@ProviderDefined(name = "mygraph:Person", excludedFields = {"internalId"}) +public class Person { ... } +---- + +NOTE: For annotation-based round-trip hydration (see <>), an annotated class must expose a no-arg +constructor and the mapped fields must be directly settable (e.g. public fields). Classes that cannot meet these +requirements — for example those with immutable `final` fields or no default constructor — should instead use a +`ProviderDefinedTypeAdapter` (see <>), which gives full control over construction. + +==== Nested Types + +PDT supports nested custom types. Each nested type must also be annotated: + +[source,java] +---- +@ProviderDefined(name = "mygraph:Address") +public class Address { + public String street; + public String city; +} + +@ProviderDefined(name = "mygraph:Person") +public class Person { + public String name; + public Address address; +} +---- + +When serialized, the `address` field is itself encoded as a PDT value. + +[[adapter-for-types-you-don-t-own]] +==== Adapter for Types You Don't Own + +For classes you cannot annotate (e.g. `java.awt.Color`), implement `ProviderDefinedTypeAdapter`: + +[source,java] +---- +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeAdapter; + +public class ColorAdapter implements ProviderDefinedTypeAdapter { + + @Override + public String typeName() { return "mygraph:Color"; } + + @Override + public Class targetClass() { return java.awt.Color.class; } + + @Override + public Map toProperties(java.awt.Color color) { + return Map.of("r", color.getRed(), "g", color.getGreen(), + "b", color.getBlue(), "a", color.getAlpha()); + } + + @Override + public java.awt.Color fromProperties(Map fields) { + return new java.awt.Color((int) fields.get("r"), (int) fields.get("g"), + (int) fields.get("b"), (int) fields.get("a")); + } +} +---- + +[[round-trip-support]] +==== Round-Trip Support (Dehydration and Hydration) + +There is an important distinction between *dehydration* (serializing a type for sending) and *hydration* (deserializing +a received PDT back into a language-native type). + +*Dehydration* is handled automatically for `@ProviderDefined`-annotated classes and adapter-registered types. When a +user passes an annotated object into a Gremlin traversal or script, TinkerPop converts it to a PDT on the wire +without any extra configuration. + +*Hydration* — reconstructing an incoming PDT back into the original typed object — requires the driver to know which +class corresponds to a given PDT name. Without this mapping, the driver will return a generic `ProviderDefinedType` +object. To enable automatic round-trip hydration, providers must expose a pre-configured `ProviderDefinedTypeRegistry` +to users. How that registry is populated differs by language: + +===== Java + +Register annotated classes explicitly with the registry. `register(Class...)` inspects the `@ProviderDefined` +annotation to derive the type name and field mapping automatically: + +[source,java] +---- +ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.build(); +registry.register(Point.class, Address.class, Person.class); +---- + +Adapter types (for classes you don't own) are discovered automatically via `ServiceLoader` when using +`ProviderDefinedTypeRegistry.build()`. Register them by adding a file at: + +---- +META-INF/services/org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeAdapter +---- + +with the fully qualified class name of each adapter: + +---- +com.example.graph.ColorAdapter +---- + +===== Python + +Hydration is fully automatic for `@provider_defined`-decorated classes. The decorator registers the class at +definition time (import time), so any annotated type round-trips without any additional setup. + +===== .NET + +`[ProviderDefined]`-annotated types are discovered automatically. Calling `ProviderDefinedTypeRegistry.Build()` +scans all loaded assemblies for `[ProviderDefined]`-annotated types and registers them for hydration. No extra +configuration is needed — providers simply annotate their types and users call `Build()` to create the registry. + +===== JavaScript + +Register hydration adapters explicitly on a `ProviderDefinedTypeRegistry` instance, then pass it to the connection: + +[source,javascript] +---- +const registry = new ProviderDefinedTypeRegistry(); +registry.register('mygraph:Point', { + serialize: (obj) => ({ x: obj.x, y: obj.y }), + deserialize: (props) => new Point(props.x, props.y) +}, Point); +---- + +===== Go + +Register types on a `PDTRegistry` instance. Go supports either reflection-based registration (using `pdt` struct +tags) or explicit function registration: + +[source,go] +---- +registry := NewPDTRegistry() +registry.RegisterType("mygraph:Point", reflect.TypeOf(Point{})) +---- + +===== Provider Factory Pattern + +Regardless of language, the recommended pattern is for providers to expose a factory method that returns a +pre-configured `ProviderDefinedTypeRegistry`. This shields end users from needing to know which types exist or how +the registry is populated: + +[source,java] +---- +// In the provider's client library +public class MyGraphTypeRegistry { + public static ProviderDefinedTypeRegistry build() { + ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.build(); // discovers ServiceLoader adapters + registry.register(Point.class, Address.class, Person.class); // registers annotated types + return registry; + } +} +---- + +End users configure their connection in one line: + +[source,java] +---- +DriverRemoteConnection conn = DriverRemoteConnection.using(cluster); +conn.setPdtRegistry(MyGraphTypeRegistry.build()); +GraphTraversalSource g = traversal().with(conn); +---- + +With this in place, `Point` objects round-trip transparently in both directions — the annotation handles outbound +serialization and the registry handles inbound reconstruction. + +For driver users consuming PDTs, see the <> reference documentation for +each language driver. + [[gremlin-plugins]] == Gremlin Plugins diff --git a/docs/src/reference/gremlin-variants.asciidoc b/docs/src/reference/gremlin-variants.asciidoc index 521a7017bf7..2583e93ebf1 100644 --- a/docs/src/reference/gremlin-variants.asciidoc +++ b/docs/src/reference/gremlin-variants.asciidoc @@ -270,6 +270,7 @@ can be passed to the `NewClient` or `NewDriverRemoteConnection` functions as con More details can be found in provider docs link:https://tinkerpop.apache.org/docs/x.y.z/dev/provider/#_graph_driver_provider_requirements[here].|true |RequestInterceptors |Functions that modify HTTP requests before sending. Used for authentication and custom headers. |empty +|PDTRegistry |A `*PDTRegistry` for hydrating and dehydrating <>. |`nil` |========================================================= [[gremlin-go-strategies]] @@ -615,6 +616,48 @@ go run basic_gremlin.go go run modern_traversals.go ---- +[[gremlin-go-pdt]] +=== Provider Defined Types + +Provider Defined Types (PDTs) allow graph providers to expose custom types through the driver. PDT values arrive as +`*ProviderDefinedType` structs containing a `Name` and `Properties` map without any configuration. +Consult your graph provider's documentation for the list of PDTs they support. + +[source,go] +---- +results, err := g.V().Has("location").Values("location").ToList() +pdt := results[0].GetInterface().(*gremlingo.ProviderDefinedType) +fmt.Println(pdt.Name) // "x:Point" +fmt.Println(pdt.Properties) // map[x:1.0 y:2.0] +---- + +Working with raw `*ProviderDefinedType` values is always available. Using a `PDTRegistry` is an optional +convenience that automates conversion between PDT values and application types on both the request and response paths. + +Using a `PDTRegistry` for hydration and dehydration: + +[source,go] +---- +registry := gremlingo.NewPDTRegistry() +registry.RegisterFuncsWithType("x:Point", reflect.TypeOf(Point{}), + // hydrate: convert incoming PDT properties map to a Go type + func(props map[string]interface{}) (interface{}, error) { + return &Point{X: props["x"].(float64), Y: props["y"].(float64)}, nil + }, + // dehydrate: convert a Go type to a PDT properties map for sending + func(obj interface{}) (map[string]interface{}, error) { + p := obj.(*Point) + return map[string]interface{}{"x": p.X, "y": p.Y}, nil + }, +) + +remote, _ := gremlingo.NewDriverRemoteConnection("http://localhost:8182/gremlin", + func(settings *gremlingo.DriverRemoteConnectionSettings) { + settings.PDTRegistry = registry + }) +g := gremlingo.Traversal_().With(remote) +---- + [[gremlin-groovy]] == Gremlin-Groovy @@ -1504,6 +1547,72 @@ java -cp target/run-examples-shaded.jar examples.BasicGremlin java -cp target/run-examples-shaded.jar examples.ModernTraversals ---- +[[gremlin-java-pdt]] +=== Provider Defined Types + +Provider Defined Types (PDTs) allow graph providers to expose custom types through the driver. PDT values arrive as +`ProviderDefinedType` objects containing a name and properties map without any configuration. +Consult your graph provider's documentation for the list of PDTs they support. + +Receiving a raw PDT: + +[source,java] +---- +ProviderDefinedType pdt = (ProviderDefinedType) g.V().has("location").values("location").next(); +String typeName = pdt.getName(); // "x:Point" +Map props = pdt.getProperties(); // {x: 1.0, y: 2.0} +---- + +Working with raw `ProviderDefinedType` objects is always available. The following two approaches are optional +conveniences that automate conversion between PDT values and application types on both the request and response paths. + +Using a `ProviderDefinedTypeRegistry` for hydration and dehydration: +---- +public class PointAdapter implements ProviderDefinedTypeAdapter { + @Override public String typeName() { return "x:Point"; } + @Override public Class targetClass() { return Point.class; } + @Override public Map toProperties(Point p) { return Map.of("x", p.getX(), "y", p.getY()); } + @Override public Point fromProperties(Map m) { return new Point((double) m.get("x"), (double) m.get("y")); } +} +---- + +Register adapters via ServiceLoader by adding the fully qualified class name to +`META-INF/services/org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeAdapter`. The driver discovers +adapters on the classpath and automatically hydrates/dehydrates. + +For simpler cases where you own the type, annotate it directly to avoid writing an adapter: + +Annotation-based conversion with `@ProviderDefined`: + +[source,java] +---- +// includedFields: only serialize the listed fields +@ProviderDefined(name = "x:Point", includedFields = {"x", "y"}) +public class Point { + private final double x; + private final double y; + private final String internalId; // not serialized + // constructor, getters... +} + +// excludedFields: serialize all fields except those listed +@ProviderDefined(name = "x:Timestamped", excludedFields = {"createdAt"}) +public class Timestamped { + private final String value; + private final long createdAt; // not serialized + // constructor, getters... +} + +// send: Point is automatically dehydrated +g.inject(new Point(1.0, 2.0, "internal")).iterate(); + +// receive: PDT is automatically hydrated back to Point +Point p = (Point) g.V().has("location").values("location").next(); +---- + +Classes annotated with `@ProviderDefined` are automatically dehydrated when passed as traversal arguments and +hydrated on deserialization without explicit registry configuration. + [[gremlin-javascript]] == Gremlin-JavaScript @@ -1682,6 +1791,7 @@ can be passed in the constructor of a new `Client` or `DriverRemoteConnection` : |options.writer |GraphBinaryWriter |The writer to use for serializing requests. |GraphBinaryWriter |options.enableUserAgentOnConnect |Boolean |Determines if a user agent header will be sent with requests. |true |options.agent |Agent |A custom `node:http` or `node:https` Agent for connection pooling or proxy configuration. |undefined +|options.pdtRegistry |ProviderDefinedTypeRegistry |A registry for hydrating and dehydrating <>. |undefined |========================================================= [[gremlin-javascript-logging]] @@ -2101,6 +2211,42 @@ node basic-gremlin.js node modern-traversals.js ---- +[[gremlin-javascript-pdt]] +=== Provider Defined Types + +Provider Defined Types (PDTs) allow graph providers to expose custom types through the driver. PDT values arrive as +`ProviderDefinedType` objects containing a `name` and `properties` map without any configuration. +Consult your graph provider's documentation for the list of PDTs they support. + +Receiving a raw PDT: + +[source,javascript] +---- +const results = await g.V().has('location').values('location').toList(); +const pdt = results[0]; +console.log(pdt.name); // "x:Point" +console.log(pdt.properties); // { x: 1.0, y: 2.0 } +---- + +Working with raw `ProviderDefinedType` objects is always available. Using a `ProviderDefinedTypeRegistry` is an +optional convenience that automates conversion between PDT values and application types on both the request and +response paths. + +[source,javascript] +---- +const { ProviderDefinedTypeRegistry } = require('gremlin'); + +const registry = new ProviderDefinedTypeRegistry(); +registry.register('x:Point', { + serialize: (point) => ({ x: point.x, y: point.y }), + deserialize: (props) => new Point(props.x, props.y) +}, Point); + +const g = traversal().with_(new DriverRemoteConnection('http://localhost:8182/gremlin', { + pdtRegistry: registry +})); +---- + anchor:gremlin-DotNet[] [[gremlin-dotnet]] == Gremlin.Net @@ -2241,6 +2387,7 @@ The following options can be passed to the `GremlinClient` constructor: |connectionSettings |The `ConnectionSettings` for the HTTP connection. |default `ConnectionSettings` |loggerFactory |An `ILoggerFactory` for logging. |`NullLoggerFactory` |interceptors |A list of `Func` that modify HTTP requests before sending. |_none_ +|pdtRegistry |A `ProviderDefinedTypeRegistry` for hydrating and dehydrating <>. |`null` |========================================================= [[gremlin-dotnet-logging]] @@ -2548,6 +2695,82 @@ dotnet run --project Connections dotnet run --project ModernTraversals ---- +[[gremlin-dotnet-pdt]] +=== Provider Defined Types + +Provider Defined Types (PDTs) allow graph providers to expose custom types through the driver. PDT values arrive as +`ProviderDefinedType` objects containing a `Name` and `Properties` dictionary without any configuration. +Consult your graph provider's documentation for the list of PDTs they support. + +Receiving a raw PDT: + +[source,csharp] +---- +var pdt = (ProviderDefinedType) g.V().Has("location").Values("location").Next(); +Console.WriteLine(pdt.Name); // "x:Point" +Console.WriteLine(pdt.Properties); // { x: 1.0, y: 2.0 } +---- + +Working with raw `ProviderDefinedType` objects is always available. The following two approaches are optional +conveniences that automate conversion between PDT values and application types on both the request and response paths. + +Using a `ProviderDefinedTypeRegistry` for hydration and dehydration: + +[source,csharp] +---- +public class PointAdapter : IProviderDefinedTypeAdapter +{ + public string TypeName => "x:Point"; + + public Point FromProperties(IReadOnlyDictionary properties) => + new Point((double)properties["x"], (double)properties["y"]); + + public IReadOnlyDictionary ToProperties(Point value) => + new Dictionary { ["x"] = value.X, ["y"] = value.Y }; +} + +var registry = new ProviderDefinedTypeRegistry(); +registry.Register(new PointAdapter()); + +using var client = new GremlinClient(new GremlinServer("localhost", 8182), pdtRegistry: registry); +---- + +The `ProviderDefinedTypeRegistry.Build()` method scans loaded assemblies for `IProviderDefinedTypeAdapter` +implementations and registers them automatically. + +For simpler cases where you own the type, annotate it directly to avoid writing an adapter: + +Attribute-based conversion with `[ProviderDefined]`: + +[source,csharp] +---- +// IncludedFields: only serialize the listed properties +[ProviderDefined(Name = "x:Point", IncludedFields = new[] { "X", "Y" })] +public class Point +{ + public double X { get; set; } + public double Y { get; set; } + public string InternalId { get; set; } // not serialized +} + +// ExcludedFields: serialize all properties except those listed +[ProviderDefined(Name = "x:Timestamped", ExcludedFields = new[] { "CreatedAt" })] +public class Timestamped +{ + public string Value { get; set; } + public long CreatedAt { get; set; } // not serialized +} + +// send: Point is automatically dehydrated +await g.Inject(new Point { X = 1.0, Y = 2.0 }).Promise(t => t.Iterate()); + +// receive: PDT is automatically hydrated back to Point +var p = (Point) await g.V().Has("location").Values("location").Promise(t => t.Next()); +---- + +Classes annotated with `[ProviderDefined]` are automatically dehydrated when passed as traversal arguments and +hydrated on deserialization without explicit registry configuration. + [[gremlin-python]] == Gremlin-Python @@ -2737,6 +2960,7 @@ can be passed to the `Client` or `DriverRemoteConnection` instance as keyword ar More details can be found in provider docs link:https://tinkerpop.apache.org/docs/x.y.z/dev/provider/#_graph_driver_provider_requirements[here].|True |bulk_results |Enables bulking of results on the server. |False +|pdt_registry |A `ProviderDefinedTypeRegistry` for hydrating and dehydrating <>. |`None` |========================================================= Transport options such as SSL and timeouts can be passed as keyword arguments to `Client` or @@ -3108,3 +3332,76 @@ python connections.py python basic_gremlin.py python modern_traversals.py ---- + +[[gremlin-python-pdt]] +=== Provider Defined Types + +Provider Defined Types (PDTs) allow graph providers to expose custom types through the driver. PDT values arrive as +`ProviderDefinedType` objects containing a `name` and `properties` dict without any configuration. +Consult your graph provider's documentation for the list of PDTs they support. + +Receiving a raw PDT: + +[source,python] +---- +pdt = g.V().has('location').values('location').next() +print(pdt.name) # "x:Point" +print(pdt.properties) # {'x': 1.0, 'y': 2.0} +---- + +Working with raw `ProviderDefinedType` objects is always available. The following two approaches are optional +conveniences that automate conversion between PDT values and application types on both the request and response paths. + +Using a `ProviderDefinedTypeRegistry` for hydration and dehydration: + +[source,python] +---- +from gremlin_python.structure.graph import ProviderDefinedTypeRegistry + +registry = ProviderDefinedTypeRegistry() +registry.register('x:Point', + deserialize_fn=lambda props: Point(props['x'], props['y']), + serialize_fn=lambda p: {'x': p.x, 'y': p.y}, + target_class=Point) + +g = traversal().with_(DriverRemoteConnection('http://localhost:8182/gremlin', 'g', + pdt_registry=registry)) +---- + +The `ProviderDefinedTypeRegistry.build()` class method discovers adapters via `entry_points` in `pyproject.toml` +under the `tinkerpop.pdt` group. + +For simpler cases where you own the type, the `@provider_defined` decorator enables automatic round-trip conversion +without an explicit registry: + +Decorator-based conversion with `@provider_defined`: + +[source,python] +---- +from gremlin_python.structure.graph import provider_defined + +# included_fields: only serialize the listed fields +@provider_defined(name='x:Point', included_fields=['x', 'y']) +class Point: + def __init__(self, x, y): + self.x = x + self.y = y + self._internal_id = None # not serialized + +# excluded_fields: serialize all instance fields except those listed +@provider_defined(name='x:Timestamped', excluded_fields=['created_at']) +class Timestamped: + def __init__(self, value, created_at): + self.value = value + self.created_at = created_at # not serialized + +# send: Point is automatically dehydrated +g.inject(Point(1.0, 2.0)).iterate() + +# receive: PDT is automatically hydrated back to Point +p = g.V().has('location').values('location').next() +# type(p) is Point +---- + +Objects decorated with `@provider_defined` are automatically dehydrated when passed as traversal arguments and +hydrated back into the same type when received in responses. \ No newline at end of file diff --git a/docs/src/upgrade/release-4.x.x.asciidoc b/docs/src/upgrade/release-4.x.x.asciidoc index aa14264ca0a..e0cd2e72a67 100644 --- a/docs/src/upgrade/release-4.x.x.asciidoc +++ b/docs/src/upgrade/release-4.x.x.asciidoc @@ -489,10 +489,37 @@ unwrap(toInt(29)); // 29 unwrap('hello'); // 'hello' ---- +==== Provider Defined Types + +Graph providers may now expose custom types as Provider Defined Types (PDT) (replacing the old `CustomTypeSerializer` +mechanism). The key improvement is that the default case now works out of the box — drivers deserialize PDT values +as `ProviderDefinedType` objects containing a `name` and a `properties` map without any configuration, eliminating +the serializer errors that occurred with unknown custom types in TP3. For automatic conversion between PDT values and +application-specific types, each driver supports an optional registry or annotation mechanism, which requires similar +effort to the old custom serializer approach but is entirely optional for basic usage: + +* <> +* <> +* <> +* <> +* <> + === Upgrading for Providers ==== Graph System Providers +===== Provider Defined Types + +TinkerPop 4 replaces the TP3 `CustomTypeSerializer` mechanism with Provider Defined Types (PDT). The key improvement +is that driver users receive PDT values as structured `ProviderDefinedType` objects by default, without any +configuration — eliminating the serializer errors that unknown custom types caused in TP3. Providers expose types by +annotating classes with `@ProviderDefined` or implementing `ProviderDefinedTypeAdapter` for types they don't own; +adapters are discovered automatically via ServiceLoader, requiring similar effort to the old approach but benefiting +all driver users transparently. + +See <> for full details on annotation usage, field filtering, nested types, and ServiceLoader +registration. + ==== Graph Driver Providers == TinkerPop 4.0.0-beta.2 diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/grammar/DefaultGremlinBaseVisitor.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/grammar/DefaultGremlinBaseVisitor.java index 550774cd271..760da5be7a9 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/grammar/DefaultGremlinBaseVisitor.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/grammar/DefaultGremlinBaseVisitor.java @@ -1355,6 +1355,10 @@ protected void notImplemented(final ParseTree ctx) { * {@inheritDoc} */ @Override public T visitBinaryLiteral(final GremlinParser.BinaryLiteralContext ctx) { notImplemented(ctx); return null; } + /** + * {@inheritDoc} + */ + @Override public T visitPdtLiteral(final GremlinParser.PdtLiteralContext ctx) { notImplemented(ctx); return null; } /** * {@inheritDoc} */ diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/grammar/GenericLiteralVisitor.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/grammar/GenericLiteralVisitor.java index 67dcd476652..9af6b7a3d9f 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/grammar/GenericLiteralVisitor.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/grammar/GenericLiteralVisitor.java @@ -27,6 +27,7 @@ import org.apache.tinkerpop.gremlin.process.traversal.TraversalStrategy; import org.apache.tinkerpop.gremlin.structure.T; import org.apache.tinkerpop.gremlin.structure.VertexProperty; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; import org.apache.tinkerpop.gremlin.util.DatetimeHelper; import java.math.BigDecimal; @@ -578,6 +579,22 @@ public Object visitBinaryLiteral(final GremlinParser.BinaryLiteralContext ctx) { } } + /** + * {@inheritDoc} + */ + @Override + public Object visitPdtLiteral(final GremlinParser.PdtLiteralContext ctx) { + final String name = (String) visitStringLiteral(ctx.stringLiteral()); + final Map properties = new LinkedHashMap<>(); + final Map rawMap = (Map) visitGenericMapLiteral(ctx.genericMapLiteral()); + for (final Map.Entry entry : rawMap.entrySet()) { + if (!(entry.getKey() instanceof String)) + throw new IllegalArgumentException("PDT properties map must have String keys, found: " + entry.getKey().getClass().getName()); + properties.put((String) entry.getKey(), entry.getValue()); + } + return new ProviderDefinedType(name, properties); + } + /** * {@inheritDoc} */ diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/AnonymizedTranslatorVisitor.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/AnonymizedTranslatorVisitor.java index 23b70526bc8..40f869dac35 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/AnonymizedTranslatorVisitor.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/AnonymizedTranslatorVisitor.java @@ -20,6 +20,7 @@ import org.antlr.v4.runtime.ParserRuleContext; import org.apache.tinkerpop.gremlin.language.grammar.GremlinParser; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; import java.math.BigDecimal; import java.math.BigInteger; @@ -204,4 +205,9 @@ public Void visitDurationLiteral(final GremlinParser.DurationLiteralContext ctx) public Void visitBinaryLiteral(final GremlinParser.BinaryLiteralContext ctx) { return anonymize(ctx, ByteBuffer.class); } + + @Override + public Void visitPdtLiteral(final GremlinParser.PdtLiteralContext ctx) { + return anonymize(ctx, ProviderDefinedType.class); + } } diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/DotNetTranslateVisitor.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/DotNetTranslateVisitor.java index 11546c6ca2d..7c296572cf4 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/DotNetTranslateVisitor.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/DotNetTranslateVisitor.java @@ -1207,6 +1207,16 @@ public Void visitDurationLiteral(final GremlinParser.DurationLiteralContext ctx) return null; } + @Override + public Void visitPdtLiteral(final GremlinParser.PdtLiteralContext ctx) { + sb.append("new ProviderDefinedType("); + sb.append(ctx.stringLiteral().getText()); + sb.append(", "); + visitGenericMapLiteral(ctx.genericMapLiteral()); + sb.append(")"); + return null; + } + @Override public Void visitBinaryLiteral(final GremlinParser.BinaryLiteralContext ctx) { sb.append("Convert.FromBase64String("); diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/GoTranslateVisitor.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/GoTranslateVisitor.java index fa3633800d5..48422f675e1 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/GoTranslateVisitor.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/GoTranslateVisitor.java @@ -376,6 +376,16 @@ public Void visitDurationLiteral(final GremlinParser.DurationLiteralContext ctx) return null; } + @Override + public Void visitPdtLiteral(final GremlinParser.PdtLiteralContext ctx) { + sb.append("&gremlingo.ProviderDefinedType{Name: "); + visitStringLiteral(ctx.stringLiteral()); + sb.append(", Properties: "); + visitGenericMapLiteral(ctx.genericMapLiteral()); + sb.append("}"); + return null; + } + @Override public Void visitBinaryLiteral(final GremlinParser.BinaryLiteralContext ctx) { final String base64Str = removeFirstAndLastCharacters(ctx.stringLiteral().getText()); diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/GroovyTranslateVisitor.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/GroovyTranslateVisitor.java index 009233c9eb9..2fd655a5064 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/GroovyTranslateVisitor.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/GroovyTranslateVisitor.java @@ -152,6 +152,16 @@ public Void visitDurationLiteral(final GremlinParser.DurationLiteralContext ctx) return null; } + @Override + public Void visitPdtLiteral(final GremlinParser.PdtLiteralContext ctx) { + sb.append("new ProviderDefinedType("); + sb.append(ctx.stringLiteral().getText()); + sb.append(", "); + visitGenericMapLiteral(ctx.genericMapLiteral()); + sb.append(")"); + return null; + } + @Override public Void visitBinaryLiteral(final GremlinParser.BinaryLiteralContext ctx) { sb.append("ByteBuffer.wrap(Base64.getDecoder().decode("); diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/JavaTranslateVisitor.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/JavaTranslateVisitor.java index ee03430118f..6f28494526d 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/JavaTranslateVisitor.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/JavaTranslateVisitor.java @@ -270,6 +270,16 @@ public Void visitDurationLiteral(final GremlinParser.DurationLiteralContext ctx) return null; } + @Override + public Void visitPdtLiteral(final GremlinParser.PdtLiteralContext ctx) { + sb.append("new ProviderDefinedType("); + sb.append(ctx.stringLiteral().getText()); + sb.append(", "); + visitGenericMapLiteral(ctx.genericMapLiteral()); + sb.append(")"); + return null; + } + @Override public Void visitBinaryLiteral(final GremlinParser.BinaryLiteralContext ctx) { sb.append("ByteBuffer.wrap(Base64.getDecoder().decode("); diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/JavascriptTranslateVisitor.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/JavascriptTranslateVisitor.java index 2a8e0ed07fe..cf6284d275f 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/JavascriptTranslateVisitor.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/JavascriptTranslateVisitor.java @@ -240,6 +240,16 @@ public Void visitDurationLiteral(final GremlinParser.DurationLiteralContext ctx) throw new TranslatorException("Duration literals are not supported in JavaScript"); } + @Override + public Void visitPdtLiteral(final GremlinParser.PdtLiteralContext ctx) { + sb.append("new ProviderDefinedType("); + visitStringLiteral(ctx.stringLiteral()); + sb.append(", "); + visitGenericMapLiteral(ctx.genericMapLiteral()); + sb.append(")"); + return null; + } + @Override public Void visitBinaryLiteral(final GremlinParser.BinaryLiteralContext ctx) { sb.append("Buffer.from("); diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/PythonTranslateVisitor.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/PythonTranslateVisitor.java index e6297dc5395..694d6b4238d 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/PythonTranslateVisitor.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/PythonTranslateVisitor.java @@ -317,6 +317,16 @@ public Void visitDurationLiteral(final GremlinParser.DurationLiteralContext ctx) return null; } + @Override + public Void visitPdtLiteral(final GremlinParser.PdtLiteralContext ctx) { + sb.append("ProviderDefinedType("); + visitStringLiteral(ctx.stringLiteral()); + sb.append(", "); + visitGenericMapLiteral(ctx.genericMapLiteral()); + sb.append(")"); + return null; + } + @Override public Void visitBinaryLiteral(final GremlinParser.BinaryLiteralContext ctx) { sb.append("base64.b64decode("); diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/TranslateVisitor.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/TranslateVisitor.java index a8d51f6b283..7e39b31482c 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/TranslateVisitor.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/language/translator/TranslateVisitor.java @@ -470,6 +470,12 @@ public Void visitBinaryLiteral(final GremlinParser.BinaryLiteralContext ctx) { return null; } + @Override + public Void visitPdtLiteral(final GremlinParser.PdtLiteralContext ctx) { + sb.append(ctx.getText()); + return null; + } + @Override public Void visitVariable(final GremlinParser.VariableContext ctx) { final String var = ctx.getText(); diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/remote/RemoteConnection.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/remote/RemoteConnection.java index 15be3637cae..607c2b480fc 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/remote/RemoteConnection.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/remote/RemoteConnection.java @@ -23,6 +23,7 @@ import org.apache.tinkerpop.gremlin.process.traversal.GremlinLang; import org.apache.tinkerpop.gremlin.process.traversal.Traversal; import org.apache.tinkerpop.gremlin.structure.Transaction; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; import java.lang.reflect.Constructor; import java.util.Iterator; @@ -55,6 +56,14 @@ public default Transaction tx() { */ public CompletableFuture> submitAsync(final GremlinLang gremlinLang) throws RemoteConnectionException; + /** + * Returns the {@link ProviderDefinedTypeRegistry} associated with this connection, or {@code null} if none. + * Used by the gremlin-lang translator for registry-based dehydration of unknown types. + */ + public default ProviderDefinedTypeRegistry getPdtRegistry() { + return null; + } + /** * Create a {@link RemoteConnection} from a {@code Configuration} object. The configuration must contain a * {@code gremlin.remote.remoteConnectionClass} key which is the fully qualified class name of a diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/traversal/GremlinLang.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/traversal/GremlinLang.java index 08f124015c2..14d153e7ee8 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/traversal/GremlinLang.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/traversal/GremlinLang.java @@ -30,6 +30,10 @@ import org.apache.tinkerpop.gremlin.structure.Column; import org.apache.tinkerpop.gremlin.structure.T; import org.apache.tinkerpop.gremlin.structure.Vertex; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefined; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeAdapter; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; import org.apache.tinkerpop.gremlin.util.NumberHelper; import javax.lang.model.SourceVersion; @@ -48,6 +52,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.Base64; @@ -67,10 +72,15 @@ public class GremlinLang implements Cloneable, Serializable { private Map parameters = new HashMap<>(); private String unsupportedType = ""; private List optionsStrategies = new ArrayList<>(); + private ProviderDefinedTypeRegistry pdtRegistry; public GremlinLang() { } + public GremlinLang(final ProviderDefinedTypeRegistry pdtRegistry) { + this.pdtRegistry = pdtRegistry; + } + public GremlinLang(final String sourceName, final Object... arguments) { addToGremlin(sourceName, arguments); } @@ -179,6 +189,11 @@ private String argAsString(final Object arg) { return String.format("Binary(\"%s\")", Base64.getEncoder().encodeToString((byte[]) arg)); } + if (arg instanceof ProviderDefinedType) { + final ProviderDefinedType pdt = (ProviderDefinedType) arg; + return "PDT(" + argAsString(pdt.getName()) + "," + asString((Map) pdt.getProperties()) + ")"; + } + if (arg instanceof Enum) { // special handling for enums with additional interfaces if (arg instanceof T) @@ -252,6 +267,19 @@ private String argAsString(final Object arg) { return ((Class) arg).getSimpleName(); } + if (pdtRegistry != null) { + final Optional> adapter = pdtRegistry.getAdapterByClass(arg.getClass()); + if (adapter.isPresent()) { + @SuppressWarnings("unchecked") + final Map props = ((ProviderDefinedTypeAdapter) adapter.get()).toProperties(arg); + return argAsString(new ProviderDefinedType(adapter.get().typeName(), props)); + } + } + + if (arg.getClass().isAnnotationPresent(ProviderDefined.class)) { + return argAsString(ProviderDefinedType.from(arg)); + } + unsupportedType = arg.getClass().getSimpleName(); return arg.toString(); } @@ -532,6 +560,20 @@ public List getOptionsStrategies() { return optionsStrategies; } + /** + * Sets the {@link ProviderDefinedTypeRegistry} used for registry-based dehydration of unknown types. + */ + public void setPdtRegistry(final ProviderDefinedTypeRegistry pdtRegistry) { + this.pdtRegistry = pdtRegistry; + } + + /** + * Gets the {@link ProviderDefinedTypeRegistry} used for registry-based dehydration. + */ + public ProviderDefinedTypeRegistry getPdtRegistry() { + return this.pdtRegistry; + } + public boolean isEmpty() { return this.gremlin.length() == 0; } @@ -565,6 +607,7 @@ public GremlinLang clone() { clone.gremlin.append(gremlin); clone.optionsStrategies = new ArrayList<>(this.optionsStrategies); clone.unsupportedType = this.unsupportedType; + clone.pdtRegistry = this.pdtRegistry; return clone; } catch (final CloneNotSupportedException e) { throw new IllegalStateException(e.getMessage(), e); diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/traversal/dsl/graph/GraphTraversalSource.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/traversal/dsl/graph/GraphTraversalSource.java index 684d9e8609a..aec1d6fadd8 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/traversal/dsl/graph/GraphTraversalSource.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/process/traversal/dsl/graph/GraphTraversalSource.java @@ -102,6 +102,9 @@ public GraphTraversalSource(final RemoteConnection connection) { this(EmptyGraph.instance(), TraversalStrategies.GlobalCache.getStrategies(EmptyGraph.class).clone()); this.connection = connection; this.strategies.addStrategies(new RemoteStrategy(connection)); + if (connection.getPdtRegistry() != null) { + this.gremlinLang.setPdtRegistry(connection.getPdtRegistry()); + } } @Override diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/DataType.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/DataType.java index 7c2e347f6e5..6ca99db8401 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/DataType.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/DataType.java @@ -58,7 +58,8 @@ public enum DataType { CHAR(0X80), DURATION(0X81), - CUSTOM(0), + COMPOSITE_PDT(0xF0), + MARKER(0XFD), UNSPECIFIED_NULL(0XFE); diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryReader.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryReader.java index fb3ef553eba..5f3cf240c6e 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryReader.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryReader.java @@ -19,6 +19,8 @@ package org.apache.tinkerpop.gremlin.structure.io.binary; import org.apache.tinkerpop.gremlin.structure.io.Buffer; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; import java.io.IOException; @@ -48,13 +50,19 @@ */ public class GraphBinaryReader { private final TypeSerializerRegistry registry; + private final ProviderDefinedTypeRegistry pdtRegistry; public GraphBinaryReader() { this(TypeSerializerRegistry.INSTANCE); } public GraphBinaryReader(final TypeSerializerRegistry registry) { + this(registry, null); + } + + public GraphBinaryReader(final TypeSerializerRegistry registry, final ProviderDefinedTypeRegistry pdtRegistry) { this.registry = registry; + this.pdtRegistry = pdtRegistry; } /** @@ -95,14 +103,11 @@ public T read(final Buffer buffer) throws IOException { return null; } - TypeSerializer serializer; - if (type != DataType.CUSTOM) { - serializer = registry.getSerializer(type); - } else { - final String customTypeName = this.readValue(buffer, String.class, false); - serializer = registry.getSerializerForCustomType(customTypeName); + final TypeSerializer serializer = registry.getSerializer(type); + final T result = serializer.read(buffer, this); + if (pdtRegistry != null && result instanceof ProviderDefinedType) { + return (T) pdtRegistry.hydrate((ProviderDefinedType) result); } - - return serializer.read(buffer, this); + return result; } } diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryWriter.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryWriter.java index 57190153889..ef4c07ccdac 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryWriter.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryWriter.java @@ -18,8 +18,9 @@ */ package org.apache.tinkerpop.gremlin.structure.io.binary; -import org.apache.tinkerpop.gremlin.structure.io.binary.types.CustomTypeSerializer; +import org.apache.tinkerpop.gremlin.structure.io.binary.types.ProviderDefinedTypeSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.TransformSerializer; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; import org.apache.tinkerpop.gremlin.structure.io.Buffer; import java.io.IOException; @@ -50,7 +51,6 @@ public class GraphBinaryWriter { public final static byte VERSION_BYTE = (byte)0x84; public final static byte BULKED_BYTE = (byte)0x01; private final static byte[] unspecifiedNullBytes = new byte[] { DataType.UNSPECIFIED_NULL.getCodeByte(), 0x01}; - private final static byte[] customTypeCodeBytes = new byte[] { DataType.CUSTOM.getCodeByte() }; public GraphBinaryWriter() { this(TypeSerializerRegistry.INSTANCE); @@ -76,6 +76,10 @@ public void writeValue(final T value, final Buffer buffer, final boolean nul final Class objectClass = value.getClass(); final TypeSerializer serializer = (TypeSerializer) registry.getSerializer(objectClass); + if (serializer instanceof ProviderDefinedTypeSerializer && !(value instanceof ProviderDefinedType)) { + serializer.writeValue((T) ProviderDefinedType.from(value), buffer, this, nullable); + return; + } serializer.writeValue(value, buffer, this, nullable); } @@ -92,13 +96,11 @@ public void write(final T value, final Buffer buffer) throws IOException { final Class objectClass = value.getClass(); final TypeSerializer serializer = (TypeSerializer) registry.getSerializer(objectClass); - if (serializer instanceof CustomTypeSerializer) { - // It's a custom type - CustomTypeSerializer customTypeSerializer = (CustomTypeSerializer) serializer; - - buffer.writeBytes(customTypeCodeBytes); - writeValue(customTypeSerializer.getTypeName(), buffer, false); - customTypeSerializer.write(value, buffer, this); + if (serializer instanceof ProviderDefinedTypeSerializer && !(value instanceof ProviderDefinedType)) { + // Convert @ProviderDefined-annotated object to ProviderDefinedType, then re-enter write(). + // On re-entry, ProviderDefinedType.class is directly registered in the registry, + // and the instanceof guard prevents double-wrapping. + write((T) ProviderDefinedType.from(value), buffer); return; } diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/TypeSerializerRegistry.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/TypeSerializerRegistry.java index f3d99212522..c0a3234aac0 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/TypeSerializerRegistry.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/TypeSerializerRegistry.java @@ -31,13 +31,13 @@ import org.apache.tinkerpop.gremlin.structure.T; import org.apache.tinkerpop.gremlin.structure.Vertex; import org.apache.tinkerpop.gremlin.structure.VertexProperty; -import org.apache.tinkerpop.gremlin.structure.io.IoRegistry; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefined; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; import org.apache.tinkerpop.gremlin.structure.io.binary.types.BigDecimalSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.BigIntegerSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.BulkSetSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.BinarySerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.CharSerializer; -import org.apache.tinkerpop.gremlin.structure.io.binary.types.CustomTypeSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.DurationSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.EdgeSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.EnumSerializer; @@ -48,6 +48,7 @@ import org.apache.tinkerpop.gremlin.structure.io.binary.types.DateTimeSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.PathSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.PropertySerializer; +import org.apache.tinkerpop.gremlin.structure.io.binary.types.ProviderDefinedTypeSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.SetSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.SingleTypeSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.StringSerializer; @@ -57,7 +58,6 @@ import org.apache.tinkerpop.gremlin.structure.io.binary.types.UUIDSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.VertexPropertySerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.types.VertexSerializer; -import org.javatuples.Pair; import java.io.IOException; import java.lang.reflect.Modifier; @@ -120,7 +120,8 @@ public static Builder build() { new RegistryEntry<>(Character.class, new CharSerializer()), new RegistryEntry<>(Duration.class, new DurationSerializer()), - new RegistryEntry<>(OffsetDateTime.class, new DateTimeSerializer()) + new RegistryEntry<>(OffsetDateTime.class, new DateTimeSerializer()), + new RegistryEntry<>(ProviderDefinedType.class, new ProviderDefinedTypeSerializer()) }; public static final TypeSerializerRegistry INSTANCE = build().create(); @@ -140,39 +141,10 @@ public static class Builder { *

*/ public
Builder add(final Class
type, final TypeSerializer
serializer) { - if (serializer.getDataType() == DataType.CUSTOM) { - throw new IllegalArgumentException("DataType can not be CUSTOM, use addCustomType() method instead"); - } - if (serializer.getDataType() == DataType.UNSPECIFIED_NULL) { throw new IllegalArgumentException("Adding a serializer for a UNSPECIFIED_NULL is not permitted"); } - if (serializer instanceof CustomTypeSerializer) { - throw new IllegalArgumentException( - "CustomTypeSerializer implementations are reserved for customtypes"); - } - - list.add(new RegistryEntry<>(type, serializer)); - return this; - } - - /** - * Adds a serializer for a custom type. - */ - public
Builder addCustomType(final Class
type, final CustomTypeSerializer
serializer) { - if (serializer == null) { - throw new NullPointerException("serializer can not be null"); - } - - if (serializer.getDataType() != DataType.CUSTOM) { - throw new IllegalArgumentException("Custom serializer must use CUSTOM data type"); - } - - if (serializer.getTypeName() == null) { - throw new NullPointerException("serializer custom type name can not be null"); - } - list.add(new RegistryEntry<>(type, serializer)); return this; } @@ -185,21 +157,6 @@ public Builder withFallbackResolver(final Function, TypeSerializer> return this; } - /** - * Add {@link CustomTypeSerializer} by way of an {@link IoRegistry}. The registry entries should be bound to - * {@link GraphBinaryIo}. - */ - public Builder addRegistry(final IoRegistry registry) { - if (null == registry) throw new IllegalArgumentException("The registry cannot be null"); - - final List> classSerializers = registry.find(GraphBinaryIo.class, CustomTypeSerializer.class); - for (Pair cs : classSerializers) { - addCustomType(cs.getValue0(), cs.getValue1()); - } - - return this; - } - /** * Creates a new {@link TypeSerializerRegistry} instance based on the serializers added. */ @@ -225,15 +182,6 @@ public DataType getDataType() { return typeSerializer.getDataType(); } - public String getCustomTypeName() { - if (getDataType() != DataType.CUSTOM) { - return null; - } - - final CustomTypeSerializer customTypeSerializer = (CustomTypeSerializer) typeSerializer; - return customTypeSerializer.getTypeName(); - } - public TypeSerializer
getTypeSerializer() { return typeSerializer; } @@ -242,7 +190,6 @@ public TypeSerializer
getTypeSerializer() { private final Map, TypeSerializer> serializers = new HashMap<>(); private final Map, TypeSerializer> serializersByInterface = new LinkedHashMap<>(); private final Map> serializersByDataType = new HashMap<>(); - private final Map serializersByCustomTypeName = new HashMap<>(); private Function, TypeSerializer> fallbackResolver; /** @@ -291,9 +238,7 @@ private void put(final RegistryEntry entry) { serializersByInterface.put(type, serializer); } - if (serializer.getDataType() == DataType.CUSTOM) { - serializersByCustomTypeName.put(entry.getCustomTypeName(), (CustomTypeSerializer) serializer); - } else if (serializer.getDataType() != null) { + if (serializer.getDataType() != null) { serializersByDataType.put(serializer.getDataType(), serializer); } } @@ -333,7 +278,15 @@ public
TypeSerializer
getSerializer(final Class
type) throws IOExce serializer = fallbackResolver.apply(type); } - validateInstance(serializer, type.getTypeName()); + if (null == serializer && type.isAnnotationPresent(ProviderDefined.class)) { + serializer = serializersByDataType.get(DataType.COMPOSITE_PDT); + } + + if (serializer == null) { + throw new IOException(String.format( + "Serializer not found for type %s. If this is a provider-defined type, annotate the class with @ProviderDefined.", + type.getTypeName())); + } // Store the lookup match to avoid looking it up in the future serializersByImplementation.put(type, serializer); @@ -342,26 +295,9 @@ public
TypeSerializer
getSerializer(final Class
type) throws IOExce } public
TypeSerializer
getSerializer(final DataType dataType) throws IOException { - if (dataType == DataType.CUSTOM) { - throw new IllegalArgumentException("Custom type serializers can not be retrieved using this method"); - } - return validateInstance(serializersByDataType.get(dataType), dataType.toString()); } - /** - * Gets the serializer for a given custom type name. - */ - public
CustomTypeSerializer
getSerializerForCustomType(final String name) throws IOException { - final CustomTypeSerializer serializer = serializersByCustomTypeName.get(name); - - if (serializer == null) { - throw new IOException(String.format("Serializer for custom type '%s' not found", name)); - } - - return serializer; - } - private static TypeSerializer validateInstance(final TypeSerializer serializer, final String typeName) throws IOException { if (serializer == null) { throw new IOException(String.format("Serializer for type %s not found", typeName)); diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/types/ProviderDefinedTypeSerializer.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/types/ProviderDefinedTypeSerializer.java new file mode 100644 index 00000000000..d45b4cf2a94 --- /dev/null +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/types/ProviderDefinedTypeSerializer.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.binary.types; + +import org.apache.tinkerpop.gremlin.structure.io.Buffer; +import org.apache.tinkerpop.gremlin.structure.io.binary.DataType; +import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryReader; +import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryWriter; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; + +import java.io.IOException; +import java.util.Map; + +public class ProviderDefinedTypeSerializer extends SimpleTypeSerializer { + + public ProviderDefinedTypeSerializer() { + super(DataType.COMPOSITE_PDT); + } + + @Override + protected ProviderDefinedType readValue(final Buffer buffer, final GraphBinaryReader context) throws IOException { + final String name = context.read(buffer); + if (name == null || name.isEmpty()) + throw new IOException("ProviderDefinedType name cannot be null or empty"); + final Map properties = context.read(buffer); + for (final Object key : properties.keySet()) { + if (!(key instanceof String)) + throw new IOException("ProviderDefinedType properties map must have String keys, found: " + key.getClass().getName()); + } + @SuppressWarnings("unchecked") + final Map typedProperties = (Map) (Map) properties; + return new ProviderDefinedType(name, typedProperties); + } + + @Override + protected void writeValue(final ProviderDefinedType value, final Buffer buffer, final GraphBinaryWriter context) throws IOException { + context.write(value.getName(), buffer); + context.write(value.getProperties(), buffer); + } +} diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/GraphSONMapper.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/GraphSONMapper.java index 3da5c4e367f..6877b67c692 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/GraphSONMapper.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/GraphSONMapper.java @@ -21,6 +21,7 @@ import org.apache.tinkerpop.gremlin.structure.Graph; import org.apache.tinkerpop.gremlin.structure.io.IoRegistry; import org.apache.tinkerpop.gremlin.structure.io.Mapper; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; import org.apache.tinkerpop.shaded.jackson.annotation.JsonTypeInfo; import org.apache.tinkerpop.shaded.jackson.core.JsonFactory; import org.apache.tinkerpop.shaded.jackson.core.JsonGenerator; @@ -71,6 +72,7 @@ public class GraphSONMapper implements Mapper { private final GraphSONVersion version; private final TypeInfo typeInfo; private final StreamReadConstraints streamReadConstraints; + private final ProviderDefinedTypeRegistry pdtRegistry; private GraphSONMapper(final Builder builder) { this.customModules = builder.customModules; @@ -79,6 +81,7 @@ private GraphSONMapper(final Builder builder) { this.version = builder.version; this.streamReadConstraints = builder.streamReadConstraintsBuilder.build(); this.typeInfo = builder.typeInfo; + this.pdtRegistry = builder.pdtRegistry; } @Override @@ -89,6 +92,9 @@ public ObjectMapper createMapper() { } final GraphSONModule graphSONModule = version.getBuilder().create(normalize, typeInfo); + if (pdtRegistry != null && graphSONModule instanceof GraphSONModule.GraphSONModuleV4) { + ((GraphSONModule.GraphSONModuleV4) graphSONModule).setPdtRegistry(pdtRegistry); + } om.registerModule(graphSONModule); customModules.forEach(om::registerModule); @@ -186,6 +192,7 @@ public static Builder build(final GraphSONMapper mapper) { builder.loadCustomModules = mapper.loadCustomSerializers; builder.normalize = mapper.normalize; builder.typeInfo = mapper.typeInfo; + builder.pdtRegistry = mapper.pdtRegistry; builder.streamReadConstraintsBuilder = mapper.streamReadConstraints.rebuild(); return builder; @@ -217,6 +224,7 @@ public static class Builder implements Mapper.Builder { private StreamReadConstraints.Builder streamReadConstraintsBuilder = StreamReadConstraints.builder() .maxNumberLength(DEFAULT_MAX_NUMBER_LENGTH); private TypeInfo typeInfo = null; + private ProviderDefinedTypeRegistry pdtRegistry = null; private Builder() { } @@ -301,6 +309,15 @@ public Builder typeInfo(final TypeInfo typeInfo) { return this; } + /** + * Set the {@link ProviderDefinedTypeRegistry} to enable automatic hydration of + * {@link org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType} values during deserialization. + */ + public Builder pdtRegistry(final ProviderDefinedTypeRegistry pdtRegistry) { + this.pdtRegistry = pdtRegistry; + return this; + } + public Builder maxNumberLength(final int maxNumLength) { this.streamReadConstraintsBuilder.maxNumberLength(maxNumLength); return this; diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/GraphSONModule.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/GraphSONModule.java index 6e98642d7ba..5cdb5cb1c32 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/GraphSONModule.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/GraphSONModule.java @@ -81,6 +81,8 @@ import org.apache.tinkerpop.gremlin.structure.T; import org.apache.tinkerpop.gremlin.structure.Vertex; import org.apache.tinkerpop.gremlin.structure.VertexProperty; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; import org.apache.tinkerpop.gremlin.structure.util.star.DirectionalStarGraph; import org.apache.tinkerpop.gremlin.structure.util.star.StarGraphGraphSONSerializerV1; import org.apache.tinkerpop.gremlin.structure.util.star.StarGraphGraphSONSerializerV2; @@ -156,12 +158,15 @@ static final class GraphSONModuleV4 extends GraphSONModule { put(Path.class, "Path"); put(VertexProperty.class, "VertexProperty"); put(Tree.class, "Tree"); + put(ProviderDefinedType.class, "CompositePdt"); Stream.of( Direction.class, Merge.class, T.class).forEach(e -> put(e, e.getSimpleName())); }}); + private final PdtGraphSONSerializersV4.ProviderDefinedTypeJacksonDeserializer pdtDeserializer; + /** * Constructs a new object. */ @@ -178,6 +183,7 @@ protected GraphSONModuleV4(final boolean normalize, final TypeInfo typeInfo) { addSerializer(Path.class, new GraphSONSerializersV4.PathJacksonSerializer()); addSerializer(DirectionalStarGraph.class, new StarGraphGraphSONSerializerV4(normalize)); addSerializer(Tree.class, new GraphSONSerializersV4.TreeJacksonSerializer()); + addSerializer(ProviderDefinedType.class, new PdtGraphSONSerializersV4.ProviderDefinedTypeJacksonSerializer()); // java.util - use the standard jackson serializers for collections when types aren't embedded if (typeInfo != TypeInfo.NO_TYPES) { @@ -208,6 +214,8 @@ protected GraphSONModuleV4(final boolean normalize, final TypeInfo typeInfo) { addDeserializer(Path.class, new GraphSONSerializersV4.PathJacksonDeserializer()); addDeserializer(VertexProperty.class, new GraphSONSerializersV4.VertexPropertyJacksonDeserializer()); addDeserializer(Tree.class, new GraphSONSerializersV4.TreeJacksonDeserializer()); + pdtDeserializer = new PdtGraphSONSerializersV4.ProviderDefinedTypeJacksonDeserializer(); + addDeserializer(ProviderDefinedType.class, pdtDeserializer); // java.util - use the standard jackson serializers for collections when types aren't embedded if (typeInfo != TypeInfo.NO_TYPES) { @@ -232,6 +240,10 @@ public static Builder build() { return new Builder(); } + void setPdtRegistry(final ProviderDefinedTypeRegistry registry) { + pdtDeserializer.setRegistry(registry); + } + @Override public Map getTypeDefinitions() { return TYPE_DEFINITIONS; diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/PdtGraphSONSerializersV4.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/PdtGraphSONSerializersV4.java new file mode 100644 index 00000000000..fb4cd848a36 --- /dev/null +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/graphson/PdtGraphSONSerializersV4.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.graphson; + +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; +import org.apache.tinkerpop.shaded.jackson.core.JsonGenerator; +import org.apache.tinkerpop.shaded.jackson.core.JsonParser; +import org.apache.tinkerpop.shaded.jackson.core.JsonToken; +import org.apache.tinkerpop.shaded.jackson.databind.DeserializationContext; +import org.apache.tinkerpop.shaded.jackson.databind.SerializerProvider; +import org.apache.tinkerpop.shaded.jackson.databind.deser.std.StdDeserializer; +import org.apache.tinkerpop.shaded.jackson.databind.ser.std.StdScalarSerializer; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * GraphSON V4 serializers for {@link ProviderDefinedType}. + */ +final class PdtGraphSONSerializersV4 { + + private PdtGraphSONSerializersV4() { + } + + final static class ProviderDefinedTypeJacksonSerializer extends StdScalarSerializer { + + public ProviderDefinedTypeJacksonSerializer() { + super(ProviderDefinedType.class); + } + + @Override + public void serialize(final ProviderDefinedType pdt, final JsonGenerator jsonGenerator, + final SerializerProvider serializerProvider) throws IOException { + jsonGenerator.writeStartObject(); + jsonGenerator.writeStringField("type", pdt.getName()); + jsonGenerator.writeFieldName("fields"); + jsonGenerator.writeStartObject(); + for (final Map.Entry entry : pdt.getProperties().entrySet()) { + jsonGenerator.writeFieldName(entry.getKey()); + jsonGenerator.writeObject(entry.getValue()); + } + jsonGenerator.writeEndObject(); + jsonGenerator.writeEndObject(); + } + } + + static class ProviderDefinedTypeJacksonDeserializer extends StdDeserializer { + + private ProviderDefinedTypeRegistry registry; + + public ProviderDefinedTypeJacksonDeserializer() { + super(ProviderDefinedType.class); + } + + void setRegistry(final ProviderDefinedTypeRegistry registry) { + this.registry = registry; + } + + @Override + public ProviderDefinedType deserialize(final JsonParser jsonParser, final DeserializationContext deserializationContext) + throws IOException { + String typeName = null; + Map fields = new LinkedHashMap<>(); + + while (jsonParser.nextToken() != JsonToken.END_OBJECT) { + final String fieldName = jsonParser.getCurrentName(); + if ("type".equals(fieldName)) { + jsonParser.nextToken(); + typeName = jsonParser.getText(); + } else if ("fields".equals(fieldName)) { + jsonParser.nextToken(); + while (jsonParser.nextToken() != JsonToken.END_OBJECT) { + final String key = jsonParser.getCurrentName(); + jsonParser.nextToken(); + final Object value = deserializationContext.readValue(jsonParser, Object.class); + fields.put(key, value); + } + } + } + + final ProviderDefinedType pdt = new ProviderDefinedType(typeName, fields); + if (registry != null) { + final Object hydrated = registry.hydrate(pdt); + if (hydrated instanceof ProviderDefinedType) + return (ProviderDefinedType) hydrated; + // Store hydrated object back as a single-entry PDT so the typed result is accessible. + // This preserves the return type contract while enabling hydration. + return pdt.withHydrated(hydrated); + } + return pdt; + } + + @Override + public boolean isCachable() { + return true; + } + } +} diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefined.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefined.java new file mode 100644 index 00000000000..51611923760 --- /dev/null +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefined.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.pdt; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Marks a class as a provider-defined type for serialization via {@link ProviderDefinedType}. + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface ProviderDefined { + String name() default ""; + String[] includedFields() default {}; + String[] excludedFields() default {}; +} diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedType.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedType.java new file mode 100644 index 00000000000..3a67209db91 --- /dev/null +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedType.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.pdt; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +/** + * An immutable representation of a provider-defined type consisting of a name and a map of properties. + */ +public final class ProviderDefinedType { + + private static final ConcurrentHashMap, FieldCache> FIELD_CACHE = new ConcurrentHashMap<>(); + + private final String name; + private final Map properties; + private Object hydrated; + + public ProviderDefinedType(final String name, final Map properties) { + if (name == null || name.isEmpty()) + throw new IllegalArgumentException("name cannot be null or empty"); + if (properties == null) + throw new IllegalArgumentException("properties cannot be null"); + this.name = name; + this.properties = Collections.unmodifiableMap(new LinkedHashMap<>(properties)); + } + + /** + * Creates a {@code ProviderDefinedType} from an object annotated with {@link ProviderDefined}. + */ + public static ProviderDefinedType from(final Object obj) { + if (obj == null) + throw new IllegalArgumentException("obj cannot be null"); + + final Class clazz = obj.getClass(); + final FieldCache cache = FIELD_CACHE.computeIfAbsent(clazz, ProviderDefinedType::buildCache); + + final Map props = new LinkedHashMap<>(); + for (final Field field : cache.fields) { + try { + props.put(field.getName(), field.get(obj)); + } catch (Exception e) { + throw new RuntimeException("Failed to read field '" + field.getName() + "' from " + clazz.getName(), e); + } + } + + return new ProviderDefinedType(cache.name, props); + } + + /** + * Package-private access to the resolved type name for a {@link ProviderDefined}-annotated class. + * Validates the annotation and field configuration via the shared field cache. + */ + static String resolveTypeName(final Class clazz) { + return FIELD_CACHE.computeIfAbsent(clazz, ProviderDefinedType::buildCache).name; + } + + /** + * Package-private access to the resolved serializable fields for a {@link ProviderDefined}-annotated class. + */ + static Field[] resolveFields(final Class clazz) { + return FIELD_CACHE.computeIfAbsent(clazz, ProviderDefinedType::buildCache).fields; + } + + private static FieldCache buildCache(final Class clazz) { + final ProviderDefined annotation = clazz.getAnnotation(ProviderDefined.class); + if (annotation == null) + throw new IllegalArgumentException(clazz.getName() + " is not annotated with @ProviderDefined"); + + final String typeName = annotation.name().isEmpty() ? clazz.getSimpleName() : annotation.name(); + final String[] included = annotation.includedFields(); + final String[] excluded = annotation.excludedFields(); + + if (included.length > 0 && excluded.length > 0) + throw new IllegalArgumentException("@ProviderDefined cannot specify both includedFields and excludedFields"); + + final Set includedSet = included.length > 0 ? new HashSet<>(Arrays.asList(included)) : null; + final Set excludedSet = excluded.length > 0 ? new HashSet<>(Arrays.asList(excluded)) : Collections.emptySet(); + + final Field[] allFields = getAllFields(clazz).toArray(new Field[0]); + final Field[] filtered = Arrays.stream(allFields) + .filter(f -> !f.isSynthetic()) + .filter(f -> { + if (includedSet != null) return includedSet.contains(f.getName()); + return !excludedSet.contains(f.getName()); + }) + .peek(f -> f.setAccessible(true)) + .toArray(Field[]::new); + + return new FieldCache(typeName, filtered); + } + + private static List getAllFields(final Class clazz) { + final List fields = new ArrayList<>(); + Class current = clazz; + while (current != null && current != Object.class) { + fields.addAll(Arrays.asList(current.getDeclaredFields())); + current = current.getSuperclass(); + } + return fields; + } + + private static class FieldCache { + final String name; + final Field[] fields; + + FieldCache(final String name, final Field[] fields) { + this.name = name; + this.fields = fields; + } + } + + public String getName() { + return name; + } + + public Map getProperties() { + return properties; + } + + /** + * Returns a copy of this PDT with the hydrated object attached. + */ + public ProviderDefinedType withHydrated(final Object hydrated) { + final ProviderDefinedType copy = new ProviderDefinedType(this.name, this.properties); + copy.hydrated = hydrated; + return copy; + } + + /** + * Returns the hydrated object if this PDT was hydrated by a {@link ProviderDefinedTypeRegistry}, or {@code null}. + */ + public Object getHydrated() { + return hydrated; + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (!(o instanceof ProviderDefinedType)) return false; + final ProviderDefinedType that = (ProviderDefinedType) o; + return name.equals(that.name) && properties.equals(that.properties); + } + + @Override + public int hashCode() { + return Objects.hash(name, properties); + } + + @Override + public String toString() { + return "pdt[" + name + "]" + properties; + } +} diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/types/CustomTypeSerializer.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeAdapter.java similarity index 68% rename from gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/types/CustomTypeSerializer.java rename to gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeAdapter.java index a54580cab0c..86880a0584e 100644 --- a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/binary/types/CustomTypeSerializer.java +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeAdapter.java @@ -16,17 +16,16 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.tinkerpop.gremlin.structure.io.binary.types; +package org.apache.tinkerpop.gremlin.structure.io.pdt; -import org.apache.tinkerpop.gremlin.structure.io.binary.TypeSerializer; +import java.util.Map; /** - * Represents a serializer for a custom (provider specific) serializer. - * @param + * Adapter for converting between a typed object and a {@link ProviderDefinedType} property map. */ -public interface CustomTypeSerializer extends TypeSerializer { - /** - * Gets the custom type name. - */ - String getTypeName(); +public interface ProviderDefinedTypeAdapter { + String typeName(); + Class targetClass(); + Map toProperties(T obj); + T fromProperties(Map properties); } diff --git a/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeRegistry.java b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeRegistry.java new file mode 100644 index 00000000000..10d63bc92f9 --- /dev/null +++ b/gremlin-core/src/main/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeRegistry.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.pdt; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.ServiceLoader; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Registry for {@link ProviderDefinedTypeAdapter} instances that supports hydration of + * {@link ProviderDefinedType} values into typed objects. + */ +public final class ProviderDefinedTypeRegistry { + + private static final Logger logger = LoggerFactory.getLogger(ProviderDefinedTypeRegistry.class); + + private final Map> adaptersByName = new ConcurrentHashMap<>(); + private final Map, ProviderDefinedTypeAdapter> adaptersByClass = new ConcurrentHashMap<>(); + + private ProviderDefinedTypeRegistry() {} + + /** + * Creates a registry populated via {@link ServiceLoader} discovery. + */ + @SuppressWarnings("rawtypes") + public static ProviderDefinedTypeRegistry build() { + final ProviderDefinedTypeRegistry registry = new ProviderDefinedTypeRegistry(); + for (final ProviderDefinedTypeAdapter adapter : ServiceLoader.load(ProviderDefinedTypeAdapter.class)) { + registry.register(adapter); + } + return registry; + } + + /** + * Creates an empty registry for manual registration. + */ + public static ProviderDefinedTypeRegistry empty() { + return new ProviderDefinedTypeRegistry(); + } + + public void register(final ProviderDefinedTypeAdapter adapter) { + adaptersByName.put(adapter.typeName(), adapter); + adaptersByClass.put(adapter.targetClass(), adapter); + } + + /** + * Registers one or more classes annotated with {@link ProviderDefined} for automatic round-trip hydration. + * An adapter is synthesized from the annotation metadata using reflection. + * + * @throws IllegalArgumentException if any class is not annotated with {@link ProviderDefined} + */ + public void register(final Class... annotatedClasses) { + for (final Class clazz : annotatedClasses) { + register(AnnotatedTypeAdapter.of(clazz)); + } + } + + public Optional> getAdapterByName(final String name) { + return Optional.ofNullable(adaptersByName.get(name)); + } + + public Optional> getAdapterByClass(final Class clazz) { + return Optional.ofNullable(adaptersByClass.get(clazz)); + } + + /** + * Attempts to hydrate a {@link ProviderDefinedType} into a typed object using a registered adapter. + * Recursively hydrates nested PDT values in the properties map (including those inside Lists, Sets, + * and Maps) before calling the adapter. + * Returns the original PDT if no adapter is found or if the adapter throws an exception. + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public Object hydrate(final ProviderDefinedType pdt) { + final ProviderDefinedTypeAdapter adapter = adaptersByName.get(pdt.getName()); + if (adapter == null) + return pdt; + + // recursively hydrate nested PDTs in the properties map + final Map hydrated = new LinkedHashMap<>(); + for (final Map.Entry entry : pdt.getProperties().entrySet()) { + hydrated.put(entry.getKey(), hydrateValue(entry.getValue())); + } + + try { + return adapter.fromProperties(hydrated); + } catch (final Exception e) { + logger.warn("Failed to hydrate ProviderDefinedType '{}', returning raw PDT: {}", + pdt.getName(), e.getMessage()); + return pdt; + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private Object hydrateValue(final Object value) { + if (value instanceof ProviderDefinedType) + return hydrate((ProviderDefinedType) value); + if (value instanceof List) { + final List result = new ArrayList<>(); + for (final Object item : (List) value) + result.add(hydrateValue(item)); + return result; + } + if (value instanceof Set) { + final Set result = new LinkedHashSet<>(); + for (final Object item : (Set) value) + result.add(hydrateValue(item)); + return result; + } + if (value instanceof Map) { + final Map result = new LinkedHashMap<>(); + for (final Map.Entry entry : ((Map) value).entrySet()) + result.put(entry.getKey(), hydrateValue(entry.getValue())); + return result; + } + return value; + } + + /** + * A reflective adapter synthesized from a {@link ProviderDefined}-annotated class. + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + private static final class AnnotatedTypeAdapter implements ProviderDefinedTypeAdapter { + private final String typeName; + private final Class targetClass; + private final Field[] fields; + + private AnnotatedTypeAdapter(final String typeName, final Class targetClass, final Field[] fields) { + this.typeName = typeName; + this.targetClass = targetClass; + this.fields = fields; + } + + static AnnotatedTypeAdapter of(final Class clazz) { + if (!clazz.isAnnotationPresent(ProviderDefined.class)) + throw new IllegalArgumentException(clazz.getName() + " is not annotated with @ProviderDefined"); + try { + clazz.getDeclaredConstructor(); + } catch (final NoSuchMethodException e) { + throw new IllegalArgumentException(clazz.getName() + + " must have a no-arg constructor for annotation-based hydration"); + } + // reuse ProviderDefinedType's validated, cached field/name resolution + return new AnnotatedTypeAdapter<>( + ProviderDefinedType.resolveTypeName(clazz), + clazz, + ProviderDefinedType.resolveFields(clazz)); + } + + @Override public String typeName() { return typeName; } + @Override public Class targetClass() { return targetClass; } + + @Override + public Map toProperties(final T obj) { + return ProviderDefinedType.from(obj).getProperties(); + } + + @Override + public T fromProperties(final Map properties) { + try { + final java.lang.reflect.Constructor ctor = targetClass.getDeclaredConstructor(); + ctor.setAccessible(true); + final T obj = ctor.newInstance(); + for (final Field field : fields) { + final Object value = properties.get(field.getName()); + if (value != null) + field.set(obj, coerce(value, field.getType())); + } + return obj; + } catch (final ReflectiveOperationException e) { + throw new RuntimeException("Failed to hydrate " + targetClass.getName() + ": " + e, e); + } + } + + private static Object coerce(final Object value, final Class targetType) { + if (targetType.isInstance(value)) return value; + if (value instanceof Number) { + final Number n = (Number) value; + if (targetType == int.class || targetType == Integer.class) return n.intValue(); + if (targetType == long.class || targetType == Long.class) return n.longValue(); + if (targetType == double.class || targetType == Double.class) return n.doubleValue(); + if (targetType == float.class || targetType == Float.class) return n.floatValue(); + if (targetType == short.class || targetType == Short.class) return n.shortValue(); + if (targetType == byte.class || targetType == Byte.class) return n.byteValue(); + } + return value; + } + } +} diff --git a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/language/grammar/GeneralLiteralVisitorTest.java b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/language/grammar/GeneralLiteralVisitorTest.java index 71f4acce192..953565679bf 100644 --- a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/language/grammar/GeneralLiteralVisitorTest.java +++ b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/language/grammar/GeneralLiteralVisitorTest.java @@ -24,6 +24,7 @@ import org.apache.tinkerpop.gremlin.structure.Direction; import org.apache.tinkerpop.gremlin.structure.T; import org.apache.tinkerpop.gremlin.structure.VertexProperty; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; import org.hamcrest.Matchers; import org.junit.Assert; import org.junit.Test; @@ -1049,4 +1050,44 @@ public void shouldFailOnInvalidBase64() { } } } + + public static class ValidPdtLiteralTest { + @Test + public void shouldParsePdtLiteral() { + final GremlinLexer lexer = new GremlinLexer(CharStreams.fromString("PDT(\"MyType\",[\"x\":1,\"y\":\"hello\"])")); + final GremlinParser parser = new GremlinParser(new CommonTokenStream(lexer)); + final GremlinParser.PdtLiteralContext ctx = parser.pdtLiteral(); + final Object result = new GenericLiteralVisitor(new GremlinAntlrToJava()).visitPdtLiteral(ctx); + assertThat(result, instanceOf(ProviderDefinedType.class)); + final ProviderDefinedType pdt = (ProviderDefinedType) result; + assertEquals("MyType", pdt.getName()); + assertEquals(1, pdt.getProperties().get("x")); + assertEquals("hello", pdt.getProperties().get("y")); + } + + @Test + public void shouldParsePdtLiteralWithEmptyMap() { + final GremlinLexer lexer = new GremlinLexer(CharStreams.fromString("PDT(\"Empty\",[:])")); + final GremlinParser parser = new GremlinParser(new CommonTokenStream(lexer)); + final GremlinParser.PdtLiteralContext ctx = parser.pdtLiteral(); + final Object result = new GenericLiteralVisitor(new GremlinAntlrToJava()).visitPdtLiteral(ctx); + assertThat(result, instanceOf(ProviderDefinedType.class)); + final ProviderDefinedType pdt = (ProviderDefinedType) result; + assertEquals("Empty", pdt.getName()); + assertTrue(pdt.getProperties().isEmpty()); + } + + @Test + public void shouldRejectNonStringMapKey() { + final GremlinLexer lexer = new GremlinLexer(CharStreams.fromString("PDT(\"Bad\",[1:\"value\"])")); + final GremlinParser parser = new GremlinParser(new CommonTokenStream(lexer)); + final GremlinParser.PdtLiteralContext ctx = parser.pdtLiteral(); + try { + new GenericLiteralVisitor(new GremlinAntlrToJava()).visitPdtLiteral(ctx); + fail("Expected IllegalArgumentException for non-String map key"); + } catch (final IllegalArgumentException e) { + assertTrue(e.getMessage().contains("PDT properties map must have String keys, found: java.lang.Integer")); + } + } + } } diff --git a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/language/translator/GremlinTranslatorTest.java b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/language/translator/GremlinTranslatorTest.java index 8e13823450b..42ee0aaaab4 100644 --- a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/language/translator/GremlinTranslatorTest.java +++ b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/language/translator/GremlinTranslatorTest.java @@ -1480,6 +1480,15 @@ public static Collection data() { "g.inject(ByteBuffer.wrap(Base64.getDecoder().decode(\"AQID\")))", "g.inject(Buffer.from(\"AQID\",'base64'))", "g.inject(base64.b64decode('AQID'))"}, + {"g.inject(PDT(\"Point\",[\"x\":1,\"y\":2]))", + null, + "g.inject(providerdefinedtype0)", + "g.Inject(new ProviderDefinedType(\"Point\", new Dictionary {{ \"x\", 1 }, { \"y\", 2 }}))", + "g.Inject(&gremlingo.ProviderDefinedType{Name: \"Point\", Properties: map[interface{}]interface{}{\"x\": 1, \"y\": 2 }})", + "g.inject(new ProviderDefinedType(\"Point\", [\"x\":1, \"y\":2]))", + "g.inject(new ProviderDefinedType(\"Point\", new LinkedHashMap() {{ put(\"x\", 1); put(\"y\", 2); }}))", + "g.inject(new ProviderDefinedType(\"Point\", new Map([[\"x\", 1], [\"y\", 2]])))", + "g.inject(ProviderDefinedType('Point', { 'x': 1, 'y': 2 }))"}, }); } diff --git a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/process/traversal/GremlinLangTest.java b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/process/traversal/GremlinLangTest.java index 57649b697d2..e181393de76 100644 --- a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/process/traversal/GremlinLangTest.java +++ b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/process/traversal/GremlinLangTest.java @@ -28,6 +28,7 @@ import org.apache.tinkerpop.gremlin.structure.VertexProperty; import org.apache.tinkerpop.gremlin.structure.util.detached.DetachedVertex; import org.apache.tinkerpop.gremlin.structure.util.empty.EmptyGraph; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; import org.apache.tinkerpop.gremlin.structure.util.reference.ReferenceEdge; import org.apache.tinkerpop.gremlin.structure.util.reference.ReferenceVertex; import org.apache.tinkerpop.gremlin.util.DatetimeHelper; @@ -140,6 +141,19 @@ public static Iterable generateTestParameters() { {g.inject(new byte[]{1, 2, 3}), "g.inject(Binary(\"AQID\"))"}, {g.inject(new byte[]{}), "g.inject(Binary(\"\"))"}, {g.inject(new byte[]{0}), "g.inject(Binary(\"AA==\"))"}, + // PDT + {g.inject(new ProviderDefinedType("MyType", asMap("x", 1, "y", "hello"))), + "g.inject(PDT(\"MyType\",[\"x\":1,\"y\":\"hello\"]))"}, + {g.inject(new ProviderDefinedType("Empty", Collections.emptyMap())), + "g.inject(PDT(\"Empty\",[:]))"}, + // PDT with special characters in name + {g.inject(new ProviderDefinedType("say\"hello\"", asMap("v", 1))), + "g.inject(PDT(\"say\\\"hello\\\"\",[\"v\":1]))"}, + {g.inject(new ProviderDefinedType("back\\slash", asMap("v", 1))), + "g.inject(PDT(\"back\\\\slash\",[\"v\":1]))"}, + // Nested PDT + {g.inject(new ProviderDefinedType("Outer", asMap("inner", new ProviderDefinedType("Inner", asMap("v", 1))))), + "g.inject(PDT(\"Outer\",[\"inner\":PDT(\"Inner\",[\"v\":1])]))"}, }); } diff --git a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryWriterPdtTest.java b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryWriterPdtTest.java new file mode 100644 index 00000000000..863e2559da9 --- /dev/null +++ b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/binary/GraphBinaryWriterPdtTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.binary; + +import org.apache.tinkerpop.gremlin.structure.io.Buffer; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefined; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; +import org.junit.Test; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class GraphBinaryWriterPdtTest { + + private static final GraphBinaryReader reader = new GraphBinaryReader(); + private static final GraphBinaryWriter writer = new GraphBinaryWriter(); + + @ProviderDefined + static class TestPoint { + int x; + int y; + + TestPoint(int x, int y) { + this.x = x; + this.y = y; + } + } + + static class UnannotatedType { + int value = 1; + } + + @Test + public void shouldAutoConvertAnnotatedObjectToPdt() throws IOException { + final Buffer buffer = HeapBuffer.allocate(1024); + writer.write(new TestPoint(1, 2), buffer); + buffer.readerIndex(0); + + final ProviderDefinedType result = reader.read(buffer); + assertEquals("TestPoint", result.getName()); + assertEquals(1, result.getProperties().get("x")); + assertEquals(2, result.getProperties().get("y")); + } + + @Test + public void shouldThrowActionableMessageForUnannotatedType() { + final Buffer buffer = HeapBuffer.allocate(1024); + final IOException ex = assertThrows(IOException.class, () -> writer.write(new UnannotatedType(), buffer)); + assertTrue(ex.getMessage().contains("@ProviderDefined")); + assertTrue(ex.getMessage().contains("UnannotatedType")); + } + + @Test + public void shouldNotDoubleWrapProviderDefinedType() throws IOException { + final Map props = new LinkedHashMap<>(); + props.put("x", 1); + props.put("y", 2); + final ProviderDefinedType pdt = new ProviderDefinedType("TestPoint", props); + + final Buffer buffer = HeapBuffer.allocate(1024); + writer.write(pdt, buffer); + buffer.readerIndex(0); + + final ProviderDefinedType result = reader.read(buffer); + assertEquals(pdt, result); + } +} diff --git a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/binary/HeapBuffer.java b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/binary/HeapBuffer.java new file mode 100644 index 00000000000..e7f090cd1b2 --- /dev/null +++ b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/binary/HeapBuffer.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.binary; + +import org.apache.tinkerpop.gremlin.structure.io.Buffer; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; + +/** + * A simple heap-based {@link Buffer} implementation for unit testing in gremlin-core. + */ +public class HeapBuffer implements Buffer { + private byte[] data; + private int readerIndex; + private int writerIndex; + private int markedWriterIndex; + + public HeapBuffer(final int initialCapacity) { + this.data = new byte[initialCapacity]; + } + + public static HeapBuffer allocate(final int capacity) { + return new HeapBuffer(capacity); + } + + private void ensureCapacity(final int needed) { + if (writerIndex + needed > data.length) { + data = Arrays.copyOf(data, Math.max(data.length * 2, writerIndex + needed)); + } + } + + @Override public int readableBytes() { return writerIndex - readerIndex; } + @Override public int readerIndex() { return readerIndex; } + @Override public Buffer readerIndex(final int readerIndex) { this.readerIndex = readerIndex; return this; } + @Override public int writerIndex() { return writerIndex; } + @Override public Buffer writerIndex(final int writerIndex) { this.writerIndex = writerIndex; return this; } + @Override public Buffer markWriterIndex() { this.markedWriterIndex = writerIndex; return this; } + @Override public Buffer resetWriterIndex() { this.writerIndex = markedWriterIndex; return this; } + @Override public int capacity() { return data.length; } + @Override public boolean isDirect() { return false; } + + @Override + public boolean readBoolean() { return readByte() != 0; } + + @Override + public byte readByte() { return data[readerIndex++]; } + + @Override + public short readShort() { + short v = (short) ((data[readerIndex] & 0xFF) << 8 | (data[readerIndex + 1] & 0xFF)); + readerIndex += 2; + return v; + } + + @Override + public int readInt() { + int v = (data[readerIndex] & 0xFF) << 24 | (data[readerIndex + 1] & 0xFF) << 16 | + (data[readerIndex + 2] & 0xFF) << 8 | (data[readerIndex + 3] & 0xFF); + readerIndex += 4; + return v; + } + + @Override + public long readLong() { + long v = ((long)(data[readerIndex] & 0xFF) << 56) | ((long)(data[readerIndex+1] & 0xFF) << 48) | + ((long)(data[readerIndex+2] & 0xFF) << 40) | ((long)(data[readerIndex+3] & 0xFF) << 32) | + ((long)(data[readerIndex+4] & 0xFF) << 24) | ((long)(data[readerIndex+5] & 0xFF) << 16) | + ((long)(data[readerIndex+6] & 0xFF) << 8) | ((long)(data[readerIndex+7] & 0xFF)); + readerIndex += 8; + return v; + } + + @Override + public float readFloat() { return Float.intBitsToFloat(readInt()); } + + @Override + public double readDouble() { return Double.longBitsToDouble(readLong()); } + + @Override + public Buffer readBytes(final byte[] destination) { + System.arraycopy(data, readerIndex, destination, 0, destination.length); + readerIndex += destination.length; + return this; + } + + @Override + public Buffer readBytes(final byte[] destination, final int dstIndex, final int length) { + System.arraycopy(data, readerIndex, destination, dstIndex, length); + readerIndex += length; + return this; + } + + @Override + public Buffer readBytes(final ByteBuffer dst) { + int len = dst.remaining(); + dst.put(data, readerIndex, len); + readerIndex += len; + return this; + } + + @Override + public Buffer readBytes(final OutputStream out, final int length) throws IOException { + out.write(data, readerIndex, length); + readerIndex += length; + return this; + } + + @Override + public Buffer writeBoolean(final boolean value) { return writeByte(value ? 1 : 0); } + + @Override + public Buffer writeByte(final int value) { + ensureCapacity(1); + data[writerIndex++] = (byte) value; + return this; + } + + @Override + public Buffer writeShort(final int value) { + ensureCapacity(2); + data[writerIndex++] = (byte) (value >>> 8); + data[writerIndex++] = (byte) value; + return this; + } + + @Override + public Buffer writeInt(final int value) { + ensureCapacity(4); + data[writerIndex++] = (byte) (value >>> 24); + data[writerIndex++] = (byte) (value >>> 16); + data[writerIndex++] = (byte) (value >>> 8); + data[writerIndex++] = (byte) value; + return this; + } + + @Override + public Buffer writeLong(final long value) { + ensureCapacity(8); + data[writerIndex++] = (byte) (value >>> 56); + data[writerIndex++] = (byte) (value >>> 48); + data[writerIndex++] = (byte) (value >>> 40); + data[writerIndex++] = (byte) (value >>> 32); + data[writerIndex++] = (byte) (value >>> 24); + data[writerIndex++] = (byte) (value >>> 16); + data[writerIndex++] = (byte) (value >>> 8); + data[writerIndex++] = (byte) value; + return this; + } + + @Override + public Buffer writeFloat(final float value) { return writeInt(Float.floatToIntBits(value)); } + + @Override + public Buffer writeDouble(final double value) { return writeLong(Double.doubleToLongBits(value)); } + + @Override + public Buffer writeBytes(final byte[] src) { + ensureCapacity(src.length); + System.arraycopy(src, 0, data, writerIndex, src.length); + writerIndex += src.length; + return this; + } + + @Override + public Buffer writeBytes(final ByteBuffer src) { + int len = src.remaining(); + ensureCapacity(len); + src.get(data, writerIndex, len); + writerIndex += len; + return this; + } + + @Override + public Buffer writeBytes(final byte[] src, final int srcIndex, final int length) { + ensureCapacity(length); + System.arraycopy(src, srcIndex, data, writerIndex, length); + writerIndex += length; + return this; + } + + @Override public boolean release() { return true; } + @Override public Buffer retain() { return this; } + @Override public int referenceCount() { return 1; } + @Override public int nioBufferCount() { return 1; } + + @Override + public ByteBuffer[] nioBuffers() { + return new ByteBuffer[] { nioBuffer() }; + } + + @Override + public ByteBuffer[] nioBuffers(final int index, final int length) { + return new ByteBuffer[] { nioBuffer(index, length) }; + } + + @Override + public ByteBuffer nioBuffer() { + return ByteBuffer.wrap(data, readerIndex, readableBytes()).slice(); + } + + @Override + public ByteBuffer nioBuffer(final int index, final int length) { + return ByteBuffer.wrap(data, index, length).slice(); + } + + @Override + public Buffer getBytes(final int index, final byte[] dst) { + System.arraycopy(data, index, dst, 0, dst.length); + return this; + } +} diff --git a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/binary/types/ProviderDefinedTypeSerializerTest.java b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/binary/types/ProviderDefinedTypeSerializerTest.java new file mode 100644 index 00000000000..76f373609d9 --- /dev/null +++ b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/binary/types/ProviderDefinedTypeSerializerTest.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.binary.types; + +import org.apache.tinkerpop.gremlin.structure.io.Buffer; +import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryReader; +import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryWriter; +import org.apache.tinkerpop.gremlin.structure.io.binary.HeapBuffer; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeAdapter; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +public class ProviderDefinedTypeSerializerTest { + + private static final GraphBinaryReader reader = new GraphBinaryReader(); + private static final GraphBinaryWriter writer = new GraphBinaryWriter(); + + private Buffer writeAndRead(final Object value) throws IOException { + final Buffer buffer = HeapBuffer.allocate(1024); + writer.write(value, buffer); + buffer.readerIndex(0); + return buffer; + } + + @Test + public void shouldRoundTripSimplePdt() throws IOException { + final Map props = new LinkedHashMap<>(); + props.put("x", 1); + props.put("y", "hello"); + final ProviderDefinedType pdt = new ProviderDefinedType("com.example.Point", props); + + final Buffer buffer = writeAndRead(pdt); + final ProviderDefinedType result = reader.read(buffer); + + assertEquals(pdt, result); + } + + @Test + public void shouldRoundTripPdtWithNullPropertyValue() throws IOException { + final Map props = new LinkedHashMap<>(); + props.put("name", "test"); + props.put("value", null); + final ProviderDefinedType pdt = new ProviderDefinedType("com.example.Nullable", props); + + final Buffer buffer = writeAndRead(pdt); + final ProviderDefinedType result = reader.read(buffer); + + assertEquals(pdt, result); + } + + @Test + public void shouldRoundTripNestedPdt() throws IOException { + final Map innerProps = new LinkedHashMap<>(); + innerProps.put("street", "123 Main"); + final ProviderDefinedType inner = new ProviderDefinedType("com.example.Address", innerProps); + + final Map outerProps = new LinkedHashMap<>(); + outerProps.put("name", "Alice"); + outerProps.put("address", inner); + final ProviderDefinedType outer = new ProviderDefinedType("com.example.Person", outerProps); + + final Buffer buffer = writeAndRead(outer); + final ProviderDefinedType result = reader.read(buffer); + + assertEquals(outer, result); + } + + @Test + public void shouldRoundTripPdtInsideList() throws IOException { + final Map props = Collections.singletonMap("id", 42); + final ProviderDefinedType pdt = new ProviderDefinedType("com.example.Item", props); + final List list = Arrays.asList(pdt, "other"); + + final Buffer buffer = writeAndRead(list); + final List result = reader.read(buffer); + + assertEquals(list, result); + } + + @Test + public void shouldRoundTripPdtInsideMapValue() throws IOException { + final Map props = Collections.singletonMap("val", 99L); + final ProviderDefinedType pdt = new ProviderDefinedType("com.example.Wrapper", props); + final Map map = new HashMap<>(); + map.put("key", pdt); + + final Buffer buffer = writeAndRead(map); + final Map result = reader.read(buffer); + + assertEquals(map, result); + } + + @Test(expected = IOException.class) + public void shouldThrowOnEmptyNameDuringRead() throws IOException { + // Manually write a PDT with empty name to trigger the validation + final Buffer buffer = HeapBuffer.allocate(256); + // Write type code for COMPOSITE_PDT + buffer.writeByte(0xF0); + // Write value_flag = 0 (not null) + buffer.writeByte(0x00); + // Write a fully-qualified empty string: type_code STRING (0x03), value_flag 0, length 0 + buffer.writeByte(0x03); + buffer.writeByte(0x00); + buffer.writeInt(0); + // Write a fully-qualified map: type_code MAP (0x0A), value_flag 0, length 0 + buffer.writeByte(0x0A); + buffer.writeByte(0x00); + buffer.writeInt(0); + + buffer.readerIndex(0); + reader.read(buffer); + } + + @Test(expected = IOException.class) + public void shouldThrowOnNonStringKeyInPropertiesMap() throws IOException { + final Buffer buffer = HeapBuffer.allocate(256); + // Write type code for COMPOSITE_PDT (0xF0), value_flag 0 + buffer.writeByte(0xF0); + buffer.writeByte(0x00); + // Write fully-qualified String name: type STRING (0x03), flag 0, length 4, "test" + buffer.writeByte(0x03); + buffer.writeByte(0x00); + buffer.writeInt(4); + buffer.writeBytes(new byte[]{'t', 'e', 's', 't'}); + // Write fully-qualified Map: type MAP (0x0A), flag 0, length 1 (one entry) + buffer.writeByte(0x0A); + buffer.writeByte(0x00); + buffer.writeInt(1); + // Key: INT type (0x01), flag 0, value 42 + buffer.writeByte(0x01); + buffer.writeByte(0x00); + buffer.writeInt(42); + // Value: STRING type (0x03), flag 0, length 3, "val" + buffer.writeByte(0x03); + buffer.writeByte(0x00); + buffer.writeInt(3); + buffer.writeBytes(new byte[]{'v', 'a', 'l'}); + + buffer.readerIndex(0); + reader.read(buffer); + } + + @Test + public void shouldHandleNullPdt() throws IOException { + final Buffer buffer = HeapBuffer.allocate(64); + writer.write(null, buffer); + buffer.readerIndex(0); + final Object result = reader.read(buffer); + assertNull(result); + } + + @Test + public void shouldAutoHydrateWhenRegistryConfigured() throws IOException { + final ProviderDefinedTypeRegistry pdtRegistry = ProviderDefinedTypeRegistry.empty(); + pdtRegistry.register(new ProviderDefinedTypeAdapter>() { + @Override + public String typeName() { return "com.example.Point"; } + + @Override + public Class> targetClass() { return (Class) Map.class; } + + @Override + public Map fromProperties(final Map properties) { + final Map result = new LinkedHashMap<>(properties); + result.put("hydrated", true); + return result; + } + + @Override + public Map toProperties(final Map value) { return value; } + }); + + final GraphBinaryReader hydratingReader = new GraphBinaryReader( + org.apache.tinkerpop.gremlin.structure.io.binary.TypeSerializerRegistry.INSTANCE, pdtRegistry); + + final Map props = new LinkedHashMap<>(); + props.put("x", 1); + props.put("y", 2); + final ProviderDefinedType pdt = new ProviderDefinedType("com.example.Point", props); + + final Buffer buffer = writeAndRead(pdt); + final Object result = hydratingReader.read(buffer); + + // Should be the hydrated map, not a ProviderDefinedType + assertEquals(true, ((Map) result).get("hydrated")); + assertEquals(1, ((Map) result).get("x")); + assertEquals(2, ((Map) result).get("y")); + } + + @Test + public void shouldNotHydrateWhenNoRegistryConfigured() throws IOException { + final Map props = new LinkedHashMap<>(); + props.put("x", 1); + final ProviderDefinedType pdt = new ProviderDefinedType("com.example.Point", props); + + final Buffer buffer = writeAndRead(pdt); + final ProviderDefinedType result = reader.read(buffer); + + assertEquals(pdt, result); + } +} diff --git a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/graphson/PdtGraphSONSerializersV4Test.java b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/graphson/PdtGraphSONSerializersV4Test.java new file mode 100644 index 00000000000..67786dbcd85 --- /dev/null +++ b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/graphson/PdtGraphSONSerializersV4Test.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.graphson; + +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeAdapter; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; +import org.apache.tinkerpop.shaded.jackson.databind.JsonNode; +import org.apache.tinkerpop.shaded.jackson.databind.ObjectMapper; +import org.junit.Before; +import org.junit.Test; + +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** + * Tests for {@link PdtGraphSONSerializersV4}. + */ +public class PdtGraphSONSerializersV4Test extends AbstractGraphSONTest { + + private ObjectMapper mapper; + private ObjectMapper plainMapper; + + @Before + public void setUp() { + mapper = GraphSONMapper.build() + .version(GraphSONVersion.V4_0) + .addCustomModule(GraphSONXModuleV4.build()) + .typeInfo(TypeInfo.PARTIAL_TYPES) + .create().createMapper(); + plainMapper = new ObjectMapper(); + } + + @Test + public void shouldSerializeSimplePdt() throws Exception { + final Map props = new LinkedHashMap<>(); + props.put("x", 1); + props.put("y", 2); + final ProviderDefinedType pdt = new ProviderDefinedType("Point", props); + + final String json = mapper.writeValueAsString(pdt); + final JsonNode node = plainMapper.readTree(json); + + assertEquals("g:CompositePdt", node.get("@type").asText()); + final JsonNode value = node.get("@value"); + assertEquals("Point", value.get("type").asText()); + + final JsonNode fields = value.get("fields"); + assertEquals("g:Int32", fields.get("x").get("@type").asText()); + assertEquals(1, fields.get("x").get("@value").asInt()); + assertEquals("g:Int32", fields.get("y").get("@type").asText()); + assertEquals(2, fields.get("y").get("@value").asInt()); + } + + @Test + public void shouldDeserializeValidJson() throws Exception { + final String json = "{\"@type\":\"g:CompositePdt\",\"@value\":{\"type\":\"Point\",\"fields\":{\"x\":{\"@type\":\"g:Int32\",\"@value\":1},\"y\":{\"@type\":\"g:Int32\",\"@value\":2}}}}"; + final ProviderDefinedType pdt = mapper.readValue(json, ProviderDefinedType.class); + + assertEquals("Point", pdt.getName()); + assertEquals(2, pdt.getProperties().size()); + assertEquals(1, pdt.getProperties().get("x")); + assertEquals(2, pdt.getProperties().get("y")); + } + + @Test + public void shouldRoundTrip() throws Exception { + final Map props = new LinkedHashMap<>(); + props.put("x", 1); + props.put("y", 2); + final ProviderDefinedType original = new ProviderDefinedType("Point", props); + + final ProviderDefinedType result = serializeDeserialize(mapper, original, ProviderDefinedType.class); + + assertEquals(original.getName(), result.getName()); + assertEquals(original.getProperties(), result.getProperties()); + } + + @Test + public void shouldSerializeNestedPdt() throws Exception { + final Map innerProps = new LinkedHashMap<>(); + innerProps.put("x", 10); + innerProps.put("y", 20); + final ProviderDefinedType inner = new ProviderDefinedType("Point", innerProps); + + final Map outerProps = new LinkedHashMap<>(); + outerProps.put("name", "origin"); + outerProps.put("location", inner); + final ProviderDefinedType outer = new ProviderDefinedType("NamedPoint", outerProps); + + final String json = mapper.writeValueAsString(outer); + final JsonNode node = plainMapper.readTree(json); + + assertEquals("g:CompositePdt", node.get("@type").asText()); + final JsonNode fields = node.get("@value").get("fields"); + final JsonNode locationNode = fields.get("location"); + assertEquals("g:CompositePdt", locationNode.get("@type").asText()); + assertEquals("Point", locationNode.get("@value").get("type").asText()); + + // round-trip nested + final ProviderDefinedType result = serializeDeserialize(mapper, outer, ProviderDefinedType.class); + assertEquals("NamedPoint", result.getName()); + assertTrue(result.getProperties().get("location") instanceof ProviderDefinedType); + final ProviderDefinedType nestedResult = (ProviderDefinedType) result.getProperties().get("location"); + assertEquals("Point", nestedResult.getName()); + assertEquals(10, nestedResult.getProperties().get("x")); + assertEquals(20, nestedResult.getProperties().get("y")); + } + + @Test + public void shouldHandleNullFieldValues() throws Exception { + final Map props = new LinkedHashMap<>(); + props.put("name", "test"); + props.put("value", null); + final ProviderDefinedType pdt = new ProviderDefinedType("NullableType", props); + + final ProviderDefinedType result = serializeDeserialize(mapper, pdt, ProviderDefinedType.class); + + assertEquals("NullableType", result.getName()); + assertEquals("test", result.getProperties().get("name")); + assertNull(result.getProperties().get("value")); + assertTrue(result.getProperties().containsKey("value")); + } + + // --- Hydration tests --- + + static class Point { + final int x; + final int y; + Point(int x, int y) { this.x = x; this.y = y; } + } + + static class PointAdapter implements ProviderDefinedTypeAdapter { + @Override public String typeName() { return "Point"; } + @Override public Class targetClass() { return Point.class; } + @Override public Map toProperties(Point obj) { + final Map m = new HashMap<>(); + m.put("x", obj.x); + m.put("y", obj.y); + return m; + } + @Override public Point fromProperties(Map properties) { + return new Point((int) properties.get("x"), (int) properties.get("y")); + } + } + + @Test + public void shouldHydrateWhenRegistryConfigured() throws Exception { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(new PointAdapter()); + + final ObjectMapper hydratingMapper = GraphSONMapper.build() + .version(GraphSONVersion.V4_0) + .addCustomModule(GraphSONXModuleV4.build()) + .typeInfo(TypeInfo.PARTIAL_TYPES) + .pdtRegistry(registry) + .create().createMapper(); + + final Map props = new LinkedHashMap<>(); + props.put("x", 3); + props.put("y", 7); + final ProviderDefinedType pdt = new ProviderDefinedType("Point", props); + + final ProviderDefinedType result = serializeDeserialize(hydratingMapper, pdt, ProviderDefinedType.class); + + assertNotNull(result.getHydrated()); + assertTrue(result.getHydrated() instanceof Point); + assertEquals(3, ((Point) result.getHydrated()).x); + assertEquals(7, ((Point) result.getHydrated()).y); + } + + @Test + public void shouldNotHydrateWhenNoRegistryConfigured() throws Exception { + final Map props = new LinkedHashMap<>(); + props.put("x", 1); + props.put("y", 2); + final ProviderDefinedType pdt = new ProviderDefinedType("Point", props); + + final ProviderDefinedType result = serializeDeserialize(mapper, pdt, ProviderDefinedType.class); + + assertNull(result.getHydrated()); + assertEquals("Point", result.getName()); + assertEquals(1, result.getProperties().get("x")); + } + + @Test + public void shouldReturnRawPdtWhenTypeNotRegistered() throws Exception { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + // No adapter registered for "Unknown" + + final ObjectMapper hydratingMapper = GraphSONMapper.build() + .version(GraphSONVersion.V4_0) + .addCustomModule(GraphSONXModuleV4.build()) + .typeInfo(TypeInfo.PARTIAL_TYPES) + .pdtRegistry(registry) + .create().createMapper(); + + final Map props = new LinkedHashMap<>(); + props.put("a", 1); + final ProviderDefinedType pdt = new ProviderDefinedType("Unknown", props); + + final ProviderDefinedType result = serializeDeserialize(hydratingMapper, pdt, ProviderDefinedType.class); + + assertNull(result.getHydrated()); + assertEquals("Unknown", result.getName()); + assertEquals(1, result.getProperties().get("a")); + } +} diff --git a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeRegistryTest.java b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeRegistryTest.java new file mode 100644 index 00000000000..13f9a05b59e --- /dev/null +++ b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeRegistryTest.java @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.pdt; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class ProviderDefinedTypeRegistryTest { + + // Simple test type + static class Point { + final int x; + final int y; + Point(int x, int y) { this.x = x; this.y = y; } + } + + static class PointAdapter implements ProviderDefinedTypeAdapter { + @Override public String typeName() { return "Point"; } + @Override public Class targetClass() { return Point.class; } + @Override public Map toProperties(Point obj) { + final Map m = new HashMap<>(); + m.put("x", obj.x); + m.put("y", obj.y); + return m; + } + @Override public Point fromProperties(Map properties) { + return new Point((int) properties.get("x"), (int) properties.get("y")); + } + } + + // Nested test type + static class Line { + final Point start; + final Point end; + Line(Point start, Point end) { this.start = start; this.end = end; } + } + + static class LineAdapter implements ProviderDefinedTypeAdapter { + @Override public String typeName() { return "Line"; } + @Override public Class targetClass() { return Line.class; } + @Override public Map toProperties(Line obj) { + final Map m = new HashMap<>(); + m.put("start", obj.start); + m.put("end", obj.end); + return m; + } + @Override public Line fromProperties(Map properties) { + return new Line((Point) properties.get("start"), (Point) properties.get("end")); + } + } + + // Adapter that always throws + static class FailingAdapter implements ProviderDefinedTypeAdapter { + @Override public String typeName() { return "Failing"; } + @Override public Class targetClass() { return Point.class; } + @Override public Map toProperties(Point obj) { return new HashMap<>(); } + @Override public Point fromProperties(Map properties) { + throw new RuntimeException("intentional failure"); + } + } + + @Test + public void shouldHydrateSimplePdt() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(new PointAdapter()); + + final Map props = new HashMap<>(); + props.put("x", 3); + props.put("y", 7); + final ProviderDefinedType pdt = new ProviderDefinedType("Point", props); + + final Object result = registry.hydrate(pdt); + assertTrue(result instanceof Point); + assertEquals(3, ((Point) result).x); + assertEquals(7, ((Point) result).y); + } + + @Test + public void shouldReturnRawPdtWhenNoAdapterRegistered() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + + final Map props = new HashMap<>(); + props.put("x", 1); + final ProviderDefinedType pdt = new ProviderDefinedType("Unknown", props); + + final Object result = registry.hydrate(pdt); + assertSame(pdt, result); + } + + @Test + public void shouldHydrateNestedPdts() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(new PointAdapter()); + registry.register(new LineAdapter()); + + final Map startProps = new HashMap<>(); + startProps.put("x", 0); + startProps.put("y", 0); + final Map endProps = new HashMap<>(); + endProps.put("x", 5); + endProps.put("y", 5); + + final Map lineProps = new HashMap<>(); + lineProps.put("start", new ProviderDefinedType("Point", startProps)); + lineProps.put("end", new ProviderDefinedType("Point", endProps)); + final ProviderDefinedType linePdt = new ProviderDefinedType("Line", lineProps); + + final Object result = registry.hydrate(linePdt); + assertTrue(result instanceof Line); + final Line line = (Line) result; + assertEquals(0, line.start.x); + assertEquals(0, line.start.y); + assertEquals(5, line.end.x); + assertEquals(5, line.end.y); + } + + @Test + public void shouldPartiallyHydrateWhenInnerAdapterMissing() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(new LineAdapter()); + // Point adapter NOT registered + + final Map startProps = new HashMap<>(); + startProps.put("x", 1); + startProps.put("y", 2); + final ProviderDefinedType startPdt = new ProviderDefinedType("Point", startProps); + + final Map endProps = new HashMap<>(); + endProps.put("x", 3); + endProps.put("y", 4); + final ProviderDefinedType endPdt = new ProviderDefinedType("Point", endProps); + + final Map lineProps = new HashMap<>(); + lineProps.put("start", startPdt); + lineProps.put("end", endPdt); + final ProviderDefinedType linePdt = new ProviderDefinedType("Line", lineProps); + + // Line adapter will receive ProviderDefinedType values for start/end since Point is not registered. + // The LineAdapter.fromProperties casts to Point which will throw ClassCastException, + // so hydrate should fall back to returning the raw PDT. + final Object result = registry.hydrate(linePdt); + assertSame(linePdt, result); + } + + @Test + public void shouldFallBackWhenAdapterThrows() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(new FailingAdapter()); + + final Map props = new HashMap<>(); + props.put("x", 1); + final ProviderDefinedType pdt = new ProviderDefinedType("Failing", props); + + // should not throw, should return raw PDT + final Object result = registry.hydrate(pdt); + assertSame(pdt, result); + } + + @Test + public void shouldLookUpAdapterByClass() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + final PointAdapter adapter = new PointAdapter(); + registry.register(adapter); + + final Optional> found = registry.getAdapterByClass(Point.class); + assertTrue(found.isPresent()); + assertEquals("Point", found.get().typeName()); + } + + // Collection test type + static class Polygon { + final List vertices; + Polygon(List vertices) { this.vertices = vertices; } + } + + static class PolygonAdapter implements ProviderDefinedTypeAdapter { + @Override public String typeName() { return "Polygon"; } + @Override public Class targetClass() { return Polygon.class; } + @Override public Map toProperties(Polygon obj) { + final Map m = new HashMap<>(); + m.put("vertices", obj.vertices); + return m; + } + @SuppressWarnings("unchecked") + @Override public Polygon fromProperties(Map properties) { + return new Polygon((List) properties.get("vertices")); + } + } + + @Test + public void shouldHydratePdtsInsideList() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(new PointAdapter()); + registry.register(new PolygonAdapter()); + + final Map p1 = new HashMap<>(); + p1.put("x", 1); p1.put("y", 2); + final Map p2 = new HashMap<>(); + p2.put("x", 3); p2.put("y", 4); + + final Map polyProps = new HashMap<>(); + polyProps.put("vertices", Arrays.asList( + new ProviderDefinedType("Point", p1), + new ProviderDefinedType("Point", p2))); + final ProviderDefinedType polyPdt = new ProviderDefinedType("Polygon", polyProps); + + final Object result = registry.hydrate(polyPdt); + assertTrue(result instanceof Polygon); + final Polygon polygon = (Polygon) result; + assertEquals(2, polygon.vertices.size()); + assertEquals(1, polygon.vertices.get(0).x); + assertEquals(2, polygon.vertices.get(0).y); + assertEquals(3, polygon.vertices.get(1).x); + assertEquals(4, polygon.vertices.get(1).y); + } + + @Test + public void shouldHydratePdtsInsideMapValues() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(new PointAdapter()); + + // A simple adapter that receives a map of named points + registry.register(new ProviderDefinedTypeAdapter() { + @Override public String typeName() { return "PointMap"; } + @Override public Class targetClass() { return Map.class; } + @Override public Map toProperties(Map obj) { return new HashMap<>(); } + @SuppressWarnings("unchecked") + @Override public Map fromProperties(Map properties) { + return (Map) properties.get("points"); + } + }); + + final Map p1 = new HashMap<>(); + p1.put("x", 10); p1.put("y", 20); + final Map p2 = new HashMap<>(); + p2.put("x", 30); p2.put("y", 40); + + final Map innerMap = new HashMap<>(); + innerMap.put("origin", new ProviderDefinedType("Point", p1)); + innerMap.put("target", new ProviderDefinedType("Point", p2)); + + final Map props = new HashMap<>(); + props.put("points", innerMap); + final ProviderDefinedType pdt = new ProviderDefinedType("PointMap", props); + + final Object result = registry.hydrate(pdt); + assertTrue(result instanceof Map); + @SuppressWarnings("unchecked") + final Map resultMap = (Map) result; + assertTrue(resultMap.get("origin") instanceof Point); + assertTrue(resultMap.get("target") instanceof Point); + assertEquals(10, ((Point) resultMap.get("origin")).x); + assertEquals(40, ((Point) resultMap.get("target")).y); + } + + @Test + public void shouldBuildViaServiceLoader() { + // ServiceLoader.load will find adapters on the classpath. With no META-INF/services file + // in test scope, this should produce an empty registry that still functions. + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.build(); + + final Map props = new HashMap<>(); + props.put("x", 1); + final ProviderDefinedType pdt = new ProviderDefinedType("Unregistered", props); + final Object result = registry.hydrate(pdt); + assertSame(pdt, result); + } + + // Annotated test types for register(Class...) + @ProviderDefined(name = "AnnotatedPoint") + static class AnnotatedPoint { + public int x; + public int y; + public AnnotatedPoint() {} + public AnnotatedPoint(int x, int y) { this.x = x; this.y = y; } + } + + @ProviderDefined(name = "Excluded", excludedFields = {"secret"}) + static class ExcludedFields { + public int value; + public String secret; + public ExcludedFields() {} + } + + @ProviderDefined(name = "NoCtor") + static class NoNoArgCtor { + public int x; + public NoNoArgCtor(int x) { this.x = x; } + } + + static class NotAnnotated { + public int x; + } + + @Test + public void shouldRegisterAndHydrateAnnotatedClass() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(AnnotatedPoint.class); + + final Map props = new HashMap<>(); + props.put("x", 3); + props.put("y", 7); + final Object result = registry.hydrate(new ProviderDefinedType("AnnotatedPoint", props)); + + assertTrue(result instanceof AnnotatedPoint); + assertEquals(3, ((AnnotatedPoint) result).x); + assertEquals(7, ((AnnotatedPoint) result).y); + } + + @Test + public void shouldDehydrateAnnotatedClassViaAdapter() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(AnnotatedPoint.class); + + final Optional> adapter = registry.getAdapterByClass(AnnotatedPoint.class); + assertTrue(adapter.isPresent()); + assertEquals("AnnotatedPoint", adapter.get().typeName()); + } + + @Test + public void shouldRespectExcludedFieldsWhenHydratingAnnotatedClass() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(ExcludedFields.class); + + final Map props = new HashMap<>(); + props.put("value", 42); + props.put("secret", "should-be-ignored"); + final Object result = registry.hydrate(new ProviderDefinedType("Excluded", props)); + + assertTrue(result instanceof ExcludedFields); + assertEquals(42, ((ExcludedFields) result).value); + // secret is excluded from the field mapping, so it is not set + assertEquals(null, ((ExcludedFields) result).secret); + } + + @Test + public void shouldThrowWhenRegisteringNonAnnotatedClass() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + try { + registry.register(NotAnnotated.class); + fail("Expected IllegalArgumentException for non-annotated class"); + } catch (final IllegalArgumentException e) { + assertTrue(e.getMessage().contains("not annotated with @ProviderDefined")); + } + } + + @Test + public void shouldThrowWhenRegisteringClassWithoutNoArgConstructor() { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + try { + registry.register(NoNoArgCtor.class); + fail("Expected IllegalArgumentException for class without no-arg constructor"); + } catch (final IllegalArgumentException e) { + assertTrue(e.getMessage().contains("no-arg constructor")); + } + } +} diff --git a/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeTest.java b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeTest.java new file mode 100644 index 00000000000..2c153700c8c --- /dev/null +++ b/gremlin-core/src/test/java/org/apache/tinkerpop/gremlin/structure/io/pdt/ProviderDefinedTypeTest.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.structure.io.pdt; + +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; + +public class ProviderDefinedTypeTest { + + @ProviderDefined + static class Point { + int x = 1; + int y = 2; + } + + @ProviderDefined(name = "GeoPoint") + static class NamedPoint { + double lat = 45.0; + double lon = -93.0; + } + + @ProviderDefined(includedFields = {"x"}) + static class IncludedFieldsPoint { + int x = 10; + int y = 20; + int z = 30; + } + + @ProviderDefined(excludedFields = {"z"}) + static class ExcludedFieldsPoint { + int x = 10; + int y = 20; + int z = 30; + } + + @ProviderDefined(includedFields = {"x"}, excludedFields = {"z"}) + static class ConflictingFieldsPoint { + int x = 10; + int y = 20; + int z = 30; + } + + @ProviderDefined + static class NullFieldPoint { + String label = null; + int x = 5; + } + + static class NotAnnotated { + int value = 1; + } + + static class BasePoint { + int x = 1; + int y = 2; + } + + @ProviderDefined(name = "GeoPoint") + static class InheritedPoint extends BasePoint { + String label = "origin"; + } + + @ProviderDefined(excludedFields = {"y"}) + static class InheritedExcluded extends BasePoint { + String label = "test"; + } + + @ProviderDefined(includedFields = {"x", "label"}) + static class InheritedIncluded extends BasePoint { + String label = "included"; + } + + @Test + public void shouldConstructDirectly() { + final Map props = new HashMap<>(); + props.put("x", 1); + props.put("y", 2); + final ProviderDefinedType pdt = new ProviderDefinedType("Point", props); + assertEquals("Point", pdt.getName()); + assertEquals(props, pdt.getProperties()); + } + + @Test + public void shouldBeImmutableFromInputMap() { + final Map props = new HashMap<>(); + props.put("x", 1); + final ProviderDefinedType pdt = new ProviderDefinedType("Point", props); + props.put("y", 2); + assertEquals(1, pdt.getProperties().size()); + } + + @Test + public void shouldReturnUnmodifiableProperties() { + final Map props = new HashMap<>(); + props.put("x", 1); + final ProviderDefinedType pdt = new ProviderDefinedType("Point", props); + assertThrows(UnsupportedOperationException.class, () -> pdt.getProperties().put("y", 2)); + } + + @Test + public void shouldCreateFromAnnotatedObject() { + final ProviderDefinedType pdt = ProviderDefinedType.from(new Point()); + assertEquals("Point", pdt.getName()); + assertEquals(1, pdt.getProperties().get("x")); + assertEquals(2, pdt.getProperties().get("y")); + } + + @Test + public void shouldUseCustomNameFromAnnotation() { + final ProviderDefinedType pdt = ProviderDefinedType.from(new NamedPoint()); + assertEquals("GeoPoint", pdt.getName()); + } + + @Test + public void shouldFilterWithIncludedFields() { + final ProviderDefinedType pdt = ProviderDefinedType.from(new IncludedFieldsPoint()); + assertEquals(1, pdt.getProperties().size()); + assertEquals(10, pdt.getProperties().get("x")); + } + + @Test + public void shouldFilterWithExcludedFields() { + final ProviderDefinedType pdt = ProviderDefinedType.from(new ExcludedFieldsPoint()); + assertEquals(2, pdt.getProperties().size()); + assertEquals(10, pdt.getProperties().get("x")); + assertEquals(20, pdt.getProperties().get("y")); + } + + @Test + public void shouldThrowOnNullObject() { + assertThrows(IllegalArgumentException.class, () -> ProviderDefinedType.from(null)); + } + + @Test + public void shouldThrowOnNonAnnotatedObject() { + assertThrows(IllegalArgumentException.class, () -> ProviderDefinedType.from(new NotAnnotated())); + } + + @Test + public void shouldHaveCorrectEqualsAndHashCode() { + final Map props = new HashMap<>(); + props.put("x", 1); + final ProviderDefinedType a = new ProviderDefinedType("Point", props); + final ProviderDefinedType b = new ProviderDefinedType("Point", props); + assertEquals(a, b); + assertEquals(a.hashCode(), b.hashCode()); + + final ProviderDefinedType c = new ProviderDefinedType("Other", props); + assertNotEquals(a, c); + } + + @Test + public void shouldThrowOnNullName() { + assertThrows(IllegalArgumentException.class, () -> new ProviderDefinedType(null, new HashMap<>())); + } + + @Test + public void shouldThrowOnEmptyName() { + assertThrows(IllegalArgumentException.class, () -> new ProviderDefinedType("", new HashMap<>())); + } + + @Test + public void shouldThrowOnNullProperties() { + assertThrows(IllegalArgumentException.class, () -> new ProviderDefinedType("Point", null)); + } + + @Test + public void shouldPreserveNullFieldValues() { + final ProviderDefinedType pdt = ProviderDefinedType.from(new NullFieldPoint()); + assertEquals(2, pdt.getProperties().size()); + assertEquals(null, pdt.getProperties().get("label")); + assertEquals(5, pdt.getProperties().get("x")); + } + + @Test + public void shouldThrowOnConflictingIncludedAndExcludedFields() { + assertThrows(IllegalArgumentException.class, () -> ProviderDefinedType.from(new ConflictingFieldsPoint())); + } + + @Test + public void shouldIncludeInheritedFields() { + final ProviderDefinedType pdt = ProviderDefinedType.from(new InheritedPoint()); + assertEquals("GeoPoint", pdt.getName()); + assertEquals(3, pdt.getProperties().size()); + assertEquals("origin", pdt.getProperties().get("label")); + assertEquals(1, pdt.getProperties().get("x")); + assertEquals(2, pdt.getProperties().get("y")); + } + + @Test + public void shouldExcludeInheritedFields() { + final ProviderDefinedType pdt = ProviderDefinedType.from(new InheritedExcluded()); + assertEquals(2, pdt.getProperties().size()); + assertEquals("test", pdt.getProperties().get("label")); + assertEquals(1, pdt.getProperties().get("x")); + } + + @Test + public void shouldIncludeOnlySpecifiedFieldsAcrossHierarchy() { + final ProviderDefinedType pdt = ProviderDefinedType.from(new InheritedIncluded()); + assertEquals(2, pdt.getProperties().size()); + assertEquals("included", pdt.getProperties().get("label")); + assertEquals(1, pdt.getProperties().get("x")); + } +} diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs index 50dcbcbd274..4d65eead89e 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs @@ -26,6 +26,7 @@ using System.Threading; using System.Threading.Tasks; using Gremlin.Net.Driver.Messages; +using Gremlin.Net.Structure; using Gremlin.Net.Structure.IO.GraphBinary4; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -63,15 +64,26 @@ public class GremlinClient : IGremlinClient /// and can modify headers, body, URI, and method /// before the request is sent. /// + /// + /// An optional for automatic hydration of + /// provider-defined types. + /// public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? requestSerializer, IMessageSerializer responseSerializer, ConnectionSettings? connectionSettings = null, ILoggerFactory? loggerFactory = null, - IReadOnlyList>? interceptors = null) + IReadOnlyList>? interceptors = null, + ProviderDefinedTypeRegistry? pdtRegistry = null) { connectionSettings ??= new ConnectionSettings(); LoggerFactory = loggerFactory ?? NullLoggerFactory.Instance; + if (pdtRegistry != null) + { + requestSerializer?.SetPdtRegistry(pdtRegistry); + responseSerializer.SetPdtRegistry(pdtRegistry); + } + _connection = new Connection( gremlinServer.Uri, requestSerializer, @@ -98,14 +110,19 @@ public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? requestSer /// /// An optional list of request interceptors. /// + /// + /// An optional for automatic hydration of + /// provider-defined types. + /// public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? messageSerializer = null, ConnectionSettings? connectionSettings = null, ILoggerFactory? loggerFactory = null, - IReadOnlyList>? interceptors = null) + IReadOnlyList>? interceptors = null, + ProviderDefinedTypeRegistry? pdtRegistry = null) : this(gremlinServer, messageSerializer ?? new GraphBinary4MessageSerializer(), messageSerializer ?? new GraphBinary4MessageSerializer(), - connectionSettings, loggerFactory, interceptors) + connectionSettings, loggerFactory, interceptors, pdtRegistry) { } diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs index 9abb59dc661..831a8ee379a 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs @@ -26,6 +26,7 @@ using System.Threading; using System.Threading.Tasks; using Gremlin.Net.Driver.Messages; +using Gremlin.Net.Structure; namespace Gremlin.Net.Driver { @@ -61,5 +62,12 @@ Task SerializeMessageAsync(RequestMessage requestMessage, /// An async sequence of deserialized result objects. IAsyncEnumerable DeserializeMessageAsync(Stream stream, CancellationToken cancellationToken = default); + + /// + /// Sets the for automatic hydration + /// of provider-defined types during deserialization. The default implementation + /// is a no-op for serializers that do not support PDT hydration. + /// + void SetPdtRegistry(ProviderDefinedTypeRegistry pdtRegistry) { } } } \ No newline at end of file diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs index 3340730c1b9..32d7ce01a77 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs @@ -29,6 +29,7 @@ using Gremlin.Net.Process.Remote; using Gremlin.Net.Process.Traversal; using Gremlin.Net.Process.Traversal.Strategy.Decoration; +using Gremlin.Net.Structure; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -43,6 +44,11 @@ public class DriverRemoteConnection : IRemoteConnection, IDisposable private readonly string _traversalSource; private readonly ILogger _logger; + /// + /// Gets or sets the for registry-based dehydration. + /// + public ProviderDefinedTypeRegistry? PdtRegistry { get; set; } + // All OptionsStrategy keys are passed through to the request fields. // The server filters out options that don't apply, and this allows // providers to use custom request fields via the Client directly or DRC. @@ -61,14 +67,17 @@ public class DriverRemoteConnection : IRemoteConnection, IDisposable /// An optional list of request interceptors forwarded to the underlying /// . /// + /// An optional registry for PDT hydration. /// Thrown when client is null. public DriverRemoteConnection(string host, int port, string traversalSource = "g", ILoggerFactory? loggerFactory = null, - IReadOnlyList>? interceptors = null) : this( - new GremlinClient(new GremlinServer(host, port), loggerFactory: loggerFactory, interceptors: interceptors), + IReadOnlyList>? interceptors = null, + ProviderDefinedTypeRegistry? pdtRegistry = null) : this( + new GremlinClient(new GremlinServer(host, port), loggerFactory: loggerFactory, interceptors: interceptors, pdtRegistry: pdtRegistry), traversalSource, logger: loggerFactory?.CreateLogger() ?? NullLogger.Instance) { + PdtRegistry = pdtRegistry; } /// @@ -76,10 +85,13 @@ public DriverRemoteConnection(string host, int port, string traversalSource = "g /// /// The that will be used for the connection. /// The name of the traversal source on the server to bind to. + /// An optional registry for PDT hydration. /// Thrown when client or the traversalSource is null. - public DriverRemoteConnection(IGremlinClient client, string traversalSource = "g") + public DriverRemoteConnection(IGremlinClient client, string traversalSource = "g", + ProviderDefinedTypeRegistry? pdtRegistry = null) : this(client, traversalSource, logger: null) { + PdtRegistry = pdtRegistry; } private DriverRemoteConnection(IGremlinClient client, string traversalSource, diff --git a/gremlin-dotnet/src/Gremlin.Net/Process/Remote/IRemoteConnection.cs b/gremlin-dotnet/src/Gremlin.Net/Process/Remote/IRemoteConnection.cs index a571c9546b7..89b37c35d8b 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Process/Remote/IRemoteConnection.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Process/Remote/IRemoteConnection.cs @@ -24,6 +24,7 @@ using System.Threading; using System.Threading.Tasks; using Gremlin.Net.Process.Traversal; +using Gremlin.Net.Structure; namespace Gremlin.Net.Process.Remote { @@ -56,5 +57,10 @@ public interface IRemoteConnection /// Determines if the connection is bound to a session. /// bool IsSessionBound { get; } + + /// + /// Gets the for registry-based dehydration, or null. + /// + ProviderDefinedTypeRegistry? PdtRegistry => null; } } \ No newline at end of file diff --git a/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/AnonymousTraversalSource.cs b/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/AnonymousTraversalSource.cs index aff2d4c2d17..fa22be61993 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/AnonymousTraversalSource.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/AnonymousTraversalSource.cs @@ -58,7 +58,7 @@ public static AnonymousTraversalSource Traversal() /// A configured to use the provided . public GraphTraversalSource With(IRemoteConnection remoteConnection) => new GraphTraversalSource(new List(), - new GremlinLang(), remoteConnection); + new GremlinLang { PdtRegistry = remoteConnection?.PdtRegistry }, remoteConnection!); } } \ No newline at end of file diff --git a/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/GremlinLang.cs b/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/GremlinLang.cs index 21115a8f8ff..fff3dcff881 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/GremlinLang.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Process/Traversal/GremlinLang.cs @@ -23,9 +23,11 @@ using System; using System.Collections; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Globalization; using System.Numerics; +using System.Reflection; using System.Text; using Gremlin.Net.Process.Traversal.Strategy; using Gremlin.Net.Process.Traversal.Strategy.Decoration; @@ -40,11 +42,17 @@ namespace Gremlin.Net.Process.Traversal public class GremlinLang : ICloneable, IEquatable { private static readonly object[] EmptyArray = Array.Empty(); + private static readonly ConcurrentDictionary _pdtCache = new(); private StringBuilder _gremlin = new(); private Dictionary _parameters = new(); private List _optionsStrategies = new(); + /// + /// Gets or sets the for registry-based dehydration. + /// + public ProviderDefinedTypeRegistry? PdtRegistry { get; set; } + /// /// Initializes a new instance of the class. /// @@ -329,6 +337,25 @@ private string ArgAsString(object? arg) if (arg is CardinalityValue cv) return $"Cardinality.{cv.Cardinality!.EnumValue}({ArgAsString(cv.Value)})"; + if (arg is ProviderDefinedType pdt) + { + var sb2 = new StringBuilder("["); + var count = pdt.Properties.Count; + if (count == 0) + { + sb2.Append(':'); + } + else + { + foreach (var kvp in pdt.Properties) + { + sb2.Append(ArgAsString(kvp.Key)).Append(':').Append(ArgAsString(kvp.Value)); + if (--count > 0) sb2.Append(','); + } + } + sb2.Append(']'); + return $"PDT(\"{EscapeJava(pdt.Name)}\",{sb2})"; + } if (arg is IDictionary dict) return AsString(dict); @@ -348,6 +375,33 @@ private string ArgAsString(object? arg) if (arg is Type type) return type.Name; + // Registry-based dehydration + if (PdtRegistry != null) + { + var adapterInfo = PdtRegistry.GetAdapterByType(arg.GetType()); + if (adapterInfo != null) + { + var (adapterTypeName, toProperties) = adapterInfo.Value; + var props = toProperties(arg); + return ArgAsString(new ProviderDefinedType(adapterTypeName, + new Dictionary(props))); + } + } + + // Auto-dehydrate objects annotated with [ProviderDefined] + var cached = GetPdtInfo(arg.GetType()); + if (cached != null) + { + var (typeName, fields) = cached.Value; + ProviderDefinedAttribute.RegisteredTypes.TryAdd(typeName, arg.GetType()); + var props = new Dictionary(); + foreach (var field in fields) + { + props[field.Name] = field.GetValue(arg); + } + return ArgAsString(new ProviderDefinedType(typeName, props)); + } + throw new ArgumentException( $"GremlinLang contains at least one type [{arg.GetType().Name}] that cannot be represented as text."); } @@ -728,6 +782,39 @@ private static bool IsValidIdentifier(string name) return true; } + private static (string name, PropertyInfo[] props)? GetPdtInfo(Type type) + { + return _pdtCache.GetOrAdd(type, t => + { + var attrs = t.GetCustomAttributes(typeof(ProviderDefinedAttribute), false); + if (attrs.Length == 0) return null; + + var attr = (ProviderDefinedAttribute)attrs[0]; + var typeName = string.IsNullOrEmpty(attr.Name) ? t.Name : attr.Name; + + var included = attr.IncludedFields; + var excluded = attr.ExcludedFields; + if (included is { Length: > 0 } && excluded is { Length: > 0 }) + { + throw new ArgumentException( + "[ProviderDefined] cannot specify both IncludedFields and ExcludedFields"); + } + + var includedSet = included is { Length: > 0 } ? new HashSet(included) : null; + var excludedSet = excluded is { Length: > 0 } ? new HashSet(excluded) : null; + + var allProps = t.GetProperties(BindingFlags.Public | BindingFlags.Instance); + var filtered = new List(); + foreach (var p in allProps) + { + if (includedSet != null && !includedSet.Contains(p.Name)) continue; + if (excludedSet != null && excludedSet.Contains(p.Name)) continue; + filtered.Add(p); + } + return (typeName, filtered.ToArray()); + }); + } + /// /// Creates a deep copy of this instance. /// @@ -738,7 +825,8 @@ public GremlinLang Clone() { _gremlin = new StringBuilder(_gremlin.ToString()), _parameters = new Dictionary(_parameters), - _optionsStrategies = new List(_optionsStrategies) + _optionsStrategies = new List(_optionsStrategies), + PdtRegistry = PdtRegistry }; return clone; } diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/DataType.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/DataType.cs index 9490a34302e..de4445e1e15 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/DataType.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/DataType.cs @@ -59,8 +59,8 @@ public class DataType : IEquatable // Not yet implemented // public static readonly DataType Tree = new DataType(0x2B); public static readonly DataType Merge = new DataType(0x2E); + public static readonly DataType CompositePDT = new DataType(0xF0); // Not yet implemented - // public static readonly DataType CompositePDT = new DataType(0xF0); // public static readonly DataType PrimitivePDT = new DataType(0xF1); public static readonly DataType Char = new DataType(0x80); public static readonly DataType Duration = new DataType(0x81); diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinary4MessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinary4MessageSerializer.cs index 706ab47dd01..d336ebfab83 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinary4MessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinary4MessageSerializer.cs @@ -36,7 +36,7 @@ namespace Gremlin.Net.Structure.IO.GraphBinary4 /// public class GraphBinary4MessageSerializer : IMessageSerializer { - private readonly GraphBinaryReader _reader; + private GraphBinaryReader _reader; private readonly GraphBinaryWriter _writer; private readonly RequestMessageSerializer _requestSerializer = new RequestMessageSerializer(); private readonly ResponseSerializer _responseSerializer = new ResponseSerializer(); @@ -54,6 +54,25 @@ public GraphBinary4MessageSerializer() _writer = new GraphBinaryWriter(); } + /// + /// Initializes a new instance of the class + /// with a for automatic hydration. + /// + public GraphBinary4MessageSerializer(ProviderDefinedTypeRegistry pdtRegistry) + { + _reader = new GraphBinaryReader(pdtRegistry: pdtRegistry); + _writer = new GraphBinaryWriter(); + } + + /// + /// Sets the on this serializer's reader + /// for automatic hydration of provider-defined types. + /// + public void SetPdtRegistry(ProviderDefinedTypeRegistry pdtRegistry) + { + _reader = new GraphBinaryReader(pdtRegistry: pdtRegistry); + } + /// public async Task SerializeMessageAsync(RequestMessage requestMessage, CancellationToken cancellationToken = default) diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinaryReader.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinaryReader.cs index ffe9f08f97a..97f0dd76fed 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinaryReader.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinaryReader.cs @@ -24,6 +24,7 @@ using System.IO; using System.Threading; using System.Threading.Tasks; +using Gremlin.Net.Structure; namespace Gremlin.Net.Structure.IO.GraphBinary4 { @@ -33,14 +34,17 @@ namespace Gremlin.Net.Structure.IO.GraphBinary4 public class GraphBinaryReader { private readonly TypeSerializerRegistry _registry; + private readonly ProviderDefinedTypeRegistry? _pdtRegistry; /// /// Initializes a new instance of the class. /// /// The to use for deserialization. - public GraphBinaryReader(TypeSerializerRegistry? registry = null) + /// Optional for automatic hydration. + public GraphBinaryReader(TypeSerializerRegistry? registry = null, ProviderDefinedTypeRegistry? pdtRegistry = null) { _registry = registry ?? TypeSerializerRegistry.Instance; + _pdtRegistry = pdtRegistry; } /// @@ -90,7 +94,18 @@ public async Task ReadNonNullableValueAsync(Stream stream, } var typeSerializer = _registry.GetSerializerFor(type); - return await typeSerializer.ReadAsync(stream, this, cancellationToken).ConfigureAwait(false); + var result = await typeSerializer.ReadAsync(stream, this, cancellationToken).ConfigureAwait(false); + if (result is ProviderDefinedType pdt) + { + if (_pdtRegistry != null) + { + var hydrated = _pdtRegistry.Hydrate(pdt); + if (hydrated is not ProviderDefinedType) + return hydrated; + } + return ProviderDefinedAttribute.HydrateIfRegistered(pdt); + } + return result; } } } \ No newline at end of file diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/TypeSerializerRegistry.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/TypeSerializerRegistry.cs index 82421f14a13..cb9667db624 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/TypeSerializerRegistry.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/TypeSerializerRegistry.cs @@ -64,6 +64,7 @@ public class TypeSerializerRegistry {typeof(char), new CharSerializer()}, {typeof(TimeSpan), new DurationSerializer()}, {typeof(Marker), SingleTypeSerializers.MarkerSerializer}, + {typeof(ProviderDefinedType), new CompositePDTSerializer()}, }; private readonly Dictionary _serializerByDataType = @@ -96,6 +97,7 @@ public class TypeSerializerRegistry {DataType.Char, new CharSerializer()}, {DataType.Duration, new DurationSerializer()}, {DataType.Marker, SingleTypeSerializers.MarkerSerializer}, + {DataType.CompositePDT, new CompositePDTSerializer()}, }; /// diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/Types/CompositePDTSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/Types/CompositePDTSerializer.cs new file mode 100644 index 00000000000..49c4bdff665 --- /dev/null +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/Types/CompositePDTSerializer.cs @@ -0,0 +1,78 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Gremlin.Net.Structure.IO.GraphBinary4.Types +{ + /// + /// A serializer for the CompositePDT data type. + /// + public class CompositePDTSerializer : SimpleTypeSerializer + { + /// + /// Initializes a new instance of the class. + /// + public CompositePDTSerializer() : base(DataType.CompositePDT) + { + } + + /// + protected override async Task WriteValueAsync(ProviderDefinedType value, Stream stream, + GraphBinaryWriter writer, CancellationToken cancellationToken = default) + { + // Write name as fully-qualified string + await writer.WriteAsync(value.Name, stream, cancellationToken).ConfigureAwait(false); + // Write properties as fully-qualified map + await writer.WriteAsync((IDictionary)new Dictionary(value.Properties), + stream, cancellationToken).ConfigureAwait(false); + } + + /// + protected override async Task ReadValueAsync(Stream stream, GraphBinaryReader reader, + CancellationToken cancellationToken = default) + { + var name = await reader.ReadAsync(stream, cancellationToken).ConfigureAwait(false) as string; + if (string.IsNullOrEmpty(name)) + throw new IOException("CompositePDT name cannot be null or empty."); + + var map = await reader.ReadAsync(stream, cancellationToken).ConfigureAwait(false) + as IDictionary; + + var properties = new Dictionary(); + if (map != null) + { + foreach (var kv in map) + { + properties[(string)kv.Key] = kv.Value; + } + } + + return new ProviderDefinedType(name!, properties); + } + } +} diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IProviderDefinedTypeAdapter.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IProviderDefinedTypeAdapter.cs new file mode 100644 index 00000000000..cd7af23e3cd --- /dev/null +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IProviderDefinedTypeAdapter.cs @@ -0,0 +1,49 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System.Collections.Generic; + +namespace Gremlin.Net.Structure +{ + /// + /// Adapter for hydrating a into a strongly-typed object. + /// + /// The target type to hydrate into. + public interface IProviderDefinedTypeAdapter + { + /// + /// Gets the fully-qualified type name this adapter handles. + /// + string TypeName { get; } + + /// + /// Creates a typed instance from the PDT properties. + /// + T FromProperties(IReadOnlyDictionary properties); + + /// + /// Converts a typed instance back to PDT properties. + /// + IReadOnlyDictionary ToProperties(T obj); + } +} diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/ProviderDefinedAttribute.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/ProviderDefinedAttribute.cs new file mode 100644 index 00000000000..c7d25667e58 --- /dev/null +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/ProviderDefinedAttribute.cs @@ -0,0 +1,81 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Reflection; + +namespace Gremlin.Net.Structure +{ + /// + /// Marks a class as a provider-defined type target for hydration. + /// + [AttributeUsage(AttributeTargets.Class)] + public class ProviderDefinedAttribute : Attribute + { + /// + /// Gets or sets the fully-qualified provider-defined type name. + /// + public string? Name { get; set; } + + /// + /// Gets or sets the list of property names to include during dehydration. + /// If non-null and non-empty, only these properties are serialized. + /// Cannot be combined with . + /// + public string[]? IncludedFields { get; set; } + + /// + /// Gets or sets the list of property names to exclude during dehydration. + /// If non-null and non-empty, these properties are omitted from serialization. + /// Cannot be combined with . + /// + public string[]? ExcludedFields { get; set; } + + /// + /// Static registry of annotated types keyed by PDT name, populated lazily during dehydration. + /// + internal static readonly ConcurrentDictionary RegisteredTypes = new(); + + /// + /// Hydrates a using a registered annotated type. + /// Returns the original PDT if no annotated type is registered for the name. + /// + internal static object HydrateIfRegistered(ProviderDefinedType pdt) + { + if (!RegisteredTypes.TryGetValue(pdt.Name, out var type)) + return pdt; + var obj = Activator.CreateInstance(type)!; + foreach (var (key, value) in pdt.Properties) + { + var prop = type.GetProperty(key, BindingFlags.Public | BindingFlags.Instance); + if (prop != null && prop.CanWrite && value != null) + { + prop.SetValue(obj, Convert.ChangeType(value, prop.PropertyType)); + } + } + return obj; + } + } +} diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/ProviderDefinedType.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/ProviderDefinedType.cs new file mode 100644 index 00000000000..ac6c9d60f99 --- /dev/null +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/ProviderDefinedType.cs @@ -0,0 +1,70 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Gremlin.Net.Structure +{ + /// + /// Represents a provider-defined type (PDT) with a name and a set of properties. + /// + public class ProviderDefinedType + { + /// + /// Initializes a new instance of the class. + /// + /// The fully-qualified name of the provider-defined type. + /// The properties of the provider-defined type. + public ProviderDefinedType(string name, IReadOnlyDictionary properties) + { + Name = name ?? throw new ArgumentNullException(nameof(name)); + if (string.IsNullOrEmpty(name)) throw new ArgumentException("name cannot be empty", nameof(name)); + Properties = properties ?? new Dictionary(); + } + + /// + /// Gets the fully-qualified name of this provider-defined type. + /// + public string Name { get; } + + /// + /// Gets the properties of this provider-defined type. + /// + public IReadOnlyDictionary Properties { get; } + + /// + public override string ToString() => + $"pdt[{Name}]{{{string.Join(", ", Properties.Select(kv => $"{kv.Key}={kv.Value}"))}}}"; + + /// + public override bool Equals(object? obj) => + obj is ProviderDefinedType other && Name == other.Name && + Properties.Count == other.Properties.Count && + Properties.All(kv => other.Properties.TryGetValue(kv.Key, out var v) && Equals(kv.Value, v)); + + /// + public override int GetHashCode() => HashCode.Combine(Name, Properties.Count); + } +} diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/ProviderDefinedTypeRegistry.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/ProviderDefinedTypeRegistry.cs new file mode 100644 index 00000000000..b33269c98cc --- /dev/null +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/ProviderDefinedTypeRegistry.cs @@ -0,0 +1,143 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Reflection; + +namespace Gremlin.Net.Structure +{ + /// + /// Registry for instances that hydrate + /// values into strongly-typed objects. + /// + public class ProviderDefinedTypeRegistry + { + private readonly Dictionary _adaptersByName = new(); + private readonly Dictionary _adaptersByType = new(); + + /// + /// Registers an adapter for a specific provider-defined type name. + /// + public void Register(IProviderDefinedTypeAdapter adapter) + { + _adaptersByName[adapter.TypeName] = adapter; + _adaptersByType[typeof(T)] = (adapter.TypeName, adapter); + } + + /// + /// Creates a registry populated by scanning loaded assemblies for: + /// + /// Types implementing (adapter-based hydration) + /// Types annotated with (annotation-based round-trip) + /// + /// + public static ProviderDefinedTypeRegistry Build() + { + var registry = new ProviderDefinedTypeRegistry(); + + foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies()) + { + try + { + foreach (var type in assembly.GetTypes()) + { + // Register adapter implementations + var adapterInterface = type.GetInterfaces() + .FirstOrDefault(i => i.IsGenericType && + i.GetGenericTypeDefinition() == typeof(IProviderDefinedTypeAdapter<>)); + if (adapterInterface != null && !type.IsAbstract && !type.IsInterface) + { + try + { + var adapter = Activator.CreateInstance(type); + var registerMethod = typeof(ProviderDefinedTypeRegistry) + .GetMethod(nameof(Register))! + .MakeGenericMethod(adapterInterface.GetGenericArguments()[0]); + registerMethod.Invoke(registry, new[] { adapter }); + } + catch + { + // skip types that can't be instantiated + } + } + + // Register annotated types for annotation-based round-trip hydration + var pdtAttr = type.GetCustomAttribute(); + if (pdtAttr != null) + { + var typeName = !string.IsNullOrEmpty(pdtAttr.Name) ? pdtAttr.Name : type.Name; + ProviderDefinedAttribute.RegisteredTypes.TryAdd(typeName, type); + } + } + } + catch + { + // skip assemblies that can't be reflected + } + } + + return registry; + } + + /// + /// Returns the type name and ToProperties method for the given CLR type, or null if not registered. + /// + internal (string typeName, Func>)? GetAdapterByType(Type type) + { + if (!_adaptersByType.TryGetValue(type, out var entry)) + return null; + var method = entry.adapter.GetType().GetMethod("ToProperties"); + if (method == null) return null; + return (entry.typeName, obj => (IReadOnlyDictionary)method.Invoke(entry.adapter, new[] { obj })!); + } + + /// + /// Hydrates a into a typed object using a registered adapter. + /// Returns the original PDT if no adapter is registered or if hydration fails. + /// + public object Hydrate(ProviderDefinedType pdt) + { + if (!_adaptersByName.TryGetValue(pdt.Name, out var adapterObj)) + return pdt; + try + { + var hydratedProps = new Dictionary(); + foreach (var (key, value) in pdt.Properties) + { + hydratedProps[key] = value is ProviderDefinedType nested ? Hydrate(nested) : value; + } + + var readOnlyProps = new ReadOnlyDictionary(hydratedProps); + var method = adapterObj.GetType().GetMethod("FromProperties"); + return method!.Invoke(adapterObj, new object[] { readOnlyProps })!; + } + catch (Exception) + { + return pdt; + } + } + } +} diff --git a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/DriverRemoteConnectionTests.cs b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/DriverRemoteConnectionTests.cs index c34c2320066..1f6ce57e014 100644 --- a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/DriverRemoteConnectionTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/DriverRemoteConnectionTests.cs @@ -22,10 +22,14 @@ #endregion using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Threading.Tasks; using Gremlin.Net.Driver; using Gremlin.Net.Driver.Remote; using Gremlin.Net.Process.Traversal; +using Gremlin.Net.Structure; +using Gremlin.Net.Structure.IO.GraphBinary4; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; @@ -92,4 +96,100 @@ private static GremlinLang SomeValidGremlinLang return gremlinLang; } } + + [Fact] + public void ShouldRoundTripPdtViaTraversalApi() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + using var gremlinClient = new GremlinClient(gremlinServer); + using var connection = new DriverRemoteConnection(gremlinClient, "gmodern"); + var g = AnonymousTraversalSource.Traversal().With(connection); + + var pdt = new ProviderDefinedType("TestPoint", + new Dictionary { { "x", 1 }, { "y", 2 } }); + + var results = g.Inject(pdt).ToList(); + + Assert.Single(results); + var result = Assert.IsType(results[0]); + Assert.Equal("TestPoint", result.Name); + Assert.Equal(1, result.Properties["x"]); + Assert.Equal(2, result.Properties["y"]); + } + + [Fact] + public void ShouldRoundTripTypedObjectViaRegistry() + { + var registry = new ProviderDefinedTypeRegistry(); + registry.Register(new TestPointAdapter()); + + var gremlinServer = new GremlinServer(TestHost, TestPort); + using var gremlinClient = new GremlinClient(gremlinServer, pdtRegistry: registry); + using var connection = new DriverRemoteConnection(gremlinClient, "gmodern", pdtRegistry: registry); + var g = AnonymousTraversalSource.Traversal().With(connection); + + var point = new TestPointClass { X = 5, Y = 10 }; + + var results = g.Inject(point).ToList(); + + Assert.Single(results); + var result = Assert.IsType(results[0]); + Assert.Equal(5, result.X); + Assert.Equal(10, result.Y); + } + + [Fact] + public void ShouldRoundTripAnnotatedClass() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + using var gremlinClient = new GremlinClient(gremlinServer); + using var connection = new DriverRemoteConnection(gremlinClient, "gmodern"); + var g = AnonymousTraversalSource.Traversal().With(connection); + + var point = new AnnotatedTestPoint { X = 3, Y = 7 }; + + var results = g.Inject(point).ToList(); + + Assert.Single(results); + var result = Assert.IsType(results[0]); + Assert.Equal(3, result.X); + Assert.Equal(7, result.Y); + } + + #region Test helpers + + private class TestPointClass + { + public int X { get; set; } + public int Y { get; set; } + } + + private class TestPointAdapter : IProviderDefinedTypeAdapter + { + public string TypeName => "TestPoint"; + + public TestPointClass FromProperties(IReadOnlyDictionary properties) + { + return new TestPointClass + { + X = Convert.ToInt32(properties["x"]), + Y = Convert.ToInt32(properties["y"]) + }; + } + + public IReadOnlyDictionary ToProperties(TestPointClass obj) + { + return new ReadOnlyDictionary( + new Dictionary { { "x", obj.X }, { "y", obj.Y } }); + } + } + + [ProviderDefined(Name = "TestPoint")] + private class AnnotatedTestPoint + { + public int X { get; set; } + public int Y { get; set; } + } + + #endregion } \ No newline at end of file diff --git a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs index a1fb59efc7c..fab9af53e4f 100644 --- a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs @@ -28,6 +28,7 @@ using Gremlin.Net.Driver.Exceptions; using Gremlin.Net.Driver.Messages; using Gremlin.Net.IntegrationTest.Util; +using Gremlin.Net.Structure; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; @@ -204,5 +205,72 @@ public void ShouldNotLogForDisabledLogLevel() logger.VerifyNothingWasLogged(); } + + [Fact] + public async Task ShouldRoundTripSimplePointPdt() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + using var gremlinClient = new GremlinClient(gremlinServer); + + var response = await gremlinClient.SubmitAsync( + "g.inject(PDT(\"Point\", [\"x\":1, \"y\":2]))"); + var results = await response.ToListAsync(); + + Assert.Single(results); + var pdt = Assert.IsType(results[0]); + Assert.Equal("Point", pdt.Name); + Assert.Equal(2, pdt.Properties.Count); + Assert.Equal(1, pdt.Properties["x"]); + Assert.Equal(2, pdt.Properties["y"]); + } + + [Fact] + public async Task ShouldRoundTripNestedPdt() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + using var gremlinClient = new GremlinClient(gremlinServer); + + var response = await gremlinClient.SubmitAsync( + "g.inject(PDT(\"Person\", [\"name\":\"Alice\", \"age\":30, " + + "\"address\":PDT(\"Address\", [\"street\":\"123 Main St\", \"city\":\"Springfield\", \"zip\":\"12345\"])]))"); + var results = await response.ToListAsync(); + + Assert.Single(results); + var pdt = Assert.IsType(results[0]); + Assert.Equal("Person", pdt.Name); + Assert.Equal("Alice", pdt.Properties["name"]); + Assert.Equal(30, pdt.Properties["age"]); + + var address = Assert.IsType(pdt.Properties["address"]); + Assert.Equal("Address", address.Name); + Assert.Equal("123 Main St", address.Properties["street"]); + Assert.Equal("Springfield", address.Properties["city"]); + Assert.Equal("12345", address.Properties["zip"]); + } + + [Fact] + public async Task ShouldHandlePdtInCollection() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + using var gremlinClient = new GremlinClient(gremlinServer); + + var response = await gremlinClient.SubmitAsync( + "g.inject([PDT(\"Point\", [\"x\":1, \"y\":2]), PDT(\"Point\", [\"x\":3, \"y\":4])])"); + var results = await response.ToListAsync(); + + Assert.Single(results); + var list = Assert.IsType>(results[0]); + Assert.Equal(2, list.Count); + + var p1 = Assert.IsType(list[0]); + Assert.Equal("Point", p1.Name); + Assert.Equal(1, p1.Properties["x"]); + Assert.Equal(2, p1.Properties["y"]); + + var p2 = Assert.IsType(list[1]); + Assert.Equal("Point", p2.Name); + Assert.Equal(3, p2.Properties["x"]); + Assert.Equal(4, p2.Properties["y"]); + } } } diff --git a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Gherkin/Gremlin.cs b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Gherkin/Gremlin.cs index 448457ee970..bf88cfccfaf 100644 --- a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Gherkin/Gremlin.cs +++ b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Gherkin/Gremlin.cs @@ -946,9 +946,9 @@ private static IDictionary, ITraversal>> {(g,p) =>g.AddV().Property(Cardinality.Single, new Dictionary {{ "name", "foo" }, { "age", 42 }}), (g,p) =>g.V().Has("name", "foo")}}, {"g_V_hasXname_fooX_propertyXname_setXbarX_age_43X", new List, ITraversal>> {(g,p) =>g.AddV().Property(Cardinality.Single, "name", "foo").Property("age", 42), (g,p) =>g.V().Has("name", "foo").Property(new Dictionary {{ "name", CardinalityValue.Set("bar") }, { "age", 43 }}), (g,p) =>g.V().Has("name", "foo"), (g,p) =>g.V().Has("name", "bar"), (g,p) =>g.V().Has("age", 43), (g,p) =>g.V().Has("age", 42)}}, {"g_V_hasXname_fooX_propertyXset_name_bar_age_singleX43XX", new List, ITraversal>> {(g,p) =>g.AddV().Property(Cardinality.Single, "name", "foo").Property("age", 42), (g,p) =>g.V().Has("name", "foo").Property(Cardinality.Set, new Dictionary {{ "name", "bar" }, { "age", CardinalityValue.Single(43) }}), (g,p) =>g.V().Has("name", "foo"), (g,p) =>g.V().Has("name", "bar"), (g,p) =>g.V().Has("age", 43), (g,p) =>g.V().Has("age", 42)}}, - {"g_addV_propertyXnullX", new List, ITraversal>> {(g,p) =>g.AddV((string) "person").Property((IDictionary) null), (g,p) =>g.V().HasLabel("person").Values()}}, + {"g_addV_propertyXnullX", new List, ITraversal>> {(g,p) =>g.AddV((string) "person").Property(null), (g,p) =>g.V().HasLabel("person").Values()}}, {"g_addV_propertyXemptyX", new List, ITraversal>> {(g,p) =>g.AddV((string) "person").Property(new Dictionary {}), (g,p) =>g.V().HasLabel("person").Values()}}, - {"g_addV_propertyXset_nullX", new List, ITraversal>> {(g,p) =>g.AddV((string) "foo").Property(Cardinality.Set, (IDictionary) null), (g,p) =>g.V().HasLabel("foo").Values()}}, + {"g_addV_propertyXset_nullX", new List, ITraversal>> {(g,p) =>g.AddV((string) "foo").Property(Cardinality.Set, null), (g,p) =>g.V().HasLabel("foo").Values()}}, {"g_addV_propertyXset_emptyX", new List, ITraversal>> {(g,p) =>g.AddV((string) "foo").Property(Cardinality.Set, new Dictionary {}), (g,p) =>g.V().HasLabel("person").Values()}}, {"g_addVXpersonX_propertyXname_joshX_propertyXage_nullX", new List, ITraversal>> {(g,p) =>g.AddV((string) "person").Property("name", "josh").Property("age", null), (g,p) =>g.V().Has("person", "age", (object) null)}}, {"g_addVXpersonX_propertyXname_markoX_propertyXfriendWeight_null_acl_nullX", new List, ITraversal>> {(g,p) =>g.AddV((string) "person").Property("name", "marko").Property("friendWeight", null, "acl", null), (g,p) =>g.V().Has("person", "name", "marko").Has("friendWeight", (object) null), (g,p) =>g.V().Has("person", "name", "marko").Properties("friendWeight").Has("acl", (object) null), (g,p) =>g.V().Has("person", "name", "marko").Properties("friendWeight").Count()}}, diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Process/Traversal/GremlinLangTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Process/Traversal/GremlinLangTests.cs index b4eeade1680..df80b589a68 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Process/Traversal/GremlinLangTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Process/Traversal/GremlinLangTests.cs @@ -1088,5 +1088,91 @@ public void ConvertParametersToString_escaped_string_value() var result = GremlinLang.ConvertParametersToString(parameters); Assert.Contains("\"name\":", result); } + + [Fact] + public void g_Inject_PDT_basic() + { + var pdt = new ProviderDefinedType("Point", new Dictionary { { "x", 1 }, { "y", 2 } }); + var result = _g.Inject((object)pdt).GremlinLang.GetGremlin(); + Assert.Equal("g.inject(PDT(\"Point\",[\"x\":1,\"y\":2]))", result); + } + + [Fact] + public void g_Inject_PDT_special_chars_in_name() + { + var pdt = new ProviderDefinedType("my\"type", new Dictionary { { "a", 1 } }); + var result = _g.Inject((object)pdt).GremlinLang.GetGremlin(); + Assert.Equal("g.inject(PDT(\"my\\\"type\",[\"a\":1]))", result); + } + + [Fact] + public void g_Inject_PDT_nested() + { + var inner = new ProviderDefinedType("Inner", new Dictionary { { "v", 42 } }); + var outer = new ProviderDefinedType("Outer", new Dictionary { { "child", inner } }); + var result = _g.Inject((object)outer).GremlinLang.GetGremlin(); + Assert.Equal("g.inject(PDT(\"Outer\",[\"child\":PDT(\"Inner\",[\"v\":42])]))", result); + } + + [Fact] + public void g_Inject_PDT_auto_dehydration_via_attribute() + { + var point = new TestPoint { X = 10, Y = 20 }; + var result = _g.Inject((object)point).GremlinLang.GetGremlin(); + Assert.Equal("g.inject(PDT(\"geo.Point\",[\"X\":10,\"Y\":20]))", result); + } + + [Fact] + public void g_Inject_PDT_auto_dehydration_IncludedFields() + { + var point = new IncludedFieldsPoint { X = 1, Y = 2, Z = 3 }; + var result = _g.Inject((object)point).GremlinLang.GetGremlin(); + Assert.Equal("g.inject(PDT(\"IncludedFieldsPoint\",[\"X\":1]))", result); + } + + [Fact] + public void g_Inject_PDT_auto_dehydration_ExcludedFields() + { + var point = new ExcludedFieldsPoint { X = 1, Y = 2, Z = 3 }; + var result = _g.Inject((object)point).GremlinLang.GetGremlin(); + Assert.Equal("g.inject(PDT(\"ExcludedFieldsPoint\",[\"X\":1,\"Y\":2]))", result); + } + + [Fact] + public void g_Inject_PDT_auto_dehydration_both_fields_throws() + { + var point = new BothFieldsPoint { X = 1, Y = 2 }; + Assert.Throws(() => _g.Inject((object)point).GremlinLang.GetGremlin()); + } + + [ProviderDefined(Name = "geo.Point")] + private class TestPoint + { + public int X { get; set; } + public int Y { get; set; } + } + + [ProviderDefined(IncludedFields = new[] { "X" })] + private class IncludedFieldsPoint + { + public int X { get; set; } + public int Y { get; set; } + public int Z { get; set; } + } + + [ProviderDefined(ExcludedFields = new[] { "Z" })] + private class ExcludedFieldsPoint + { + public int X { get; set; } + public int Y { get; set; } + public int Z { get; set; } + } + + [ProviderDefined(IncludedFields = new[] { "X" }, ExcludedFields = new[] { "Y" })] + private class BothFieldsPoint + { + public int X { get; set; } + public int Y { get; set; } + } } } diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/ProviderDefinedTypeTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/ProviderDefinedTypeTests.cs new file mode 100644 index 00000000000..36a6318b52e --- /dev/null +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/ProviderDefinedTypeTests.cs @@ -0,0 +1,137 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; +using Gremlin.Net.Structure; +using Gremlin.Net.Structure.IO.GraphBinary4; +using Xunit; + +namespace Gremlin.Net.UnitTest.Structure.IO.GraphBinary4 +{ + public class ProviderDefinedTypeTests + { + private static readonly GraphBinaryWriter Writer = new(); + private static readonly GraphBinaryReader Reader = new(); + + [Fact] + public async Task TestRoundTripWithProperties() + { + var properties = new Dictionary { { "x", 1 }, { "y", "hello" } }; + var expected = new ProviderDefinedType("com.example.MyType", properties); + + using var stream = new MemoryStream(); + await Writer.WriteAsync(expected, stream); + stream.Position = 0; + var actual = await Reader.ReadAsync(stream) as ProviderDefinedType; + + Assert.NotNull(actual); + Assert.Equal(expected.Name, actual!.Name); + Assert.Equal(expected.Properties, actual.Properties); + } + + [Fact] + public async Task TestRoundTripWithEmptyProperties() + { + var expected = new ProviderDefinedType("com.example.Empty", new Dictionary()); + + using var stream = new MemoryStream(); + await Writer.WriteAsync(expected, stream); + stream.Position = 0; + var actual = await Reader.ReadAsync(stream) as ProviderDefinedType; + + Assert.NotNull(actual); + Assert.Equal(expected.Name, actual!.Name); + Assert.Empty(actual.Properties); + } + + [Fact] + public async Task TestRoundTripWithNullPropertyValue() + { + var properties = new Dictionary { { "key", null } }; + var expected = new ProviderDefinedType("com.example.NullVal", properties); + + using var stream = new MemoryStream(); + await Writer.WriteAsync(expected, stream); + stream.Position = 0; + var actual = await Reader.ReadAsync(stream) as ProviderDefinedType; + + Assert.NotNull(actual); + Assert.Equal(expected.Name, actual!.Name); + Assert.Null(actual.Properties["key"]); + } + + [Fact] + public async Task TestDataTypeCode() + { + var pdt = new ProviderDefinedType("com.example.Test", new Dictionary()); + + using var stream = new MemoryStream(); + await Writer.WriteAsync(pdt, stream); + + // First byte should be the CompositePDT type code 0xF0 + Assert.Equal(0xF0, stream.ToArray()[0]); + } + + [Fact] + public void TestConstructorThrowsOnNullName() + { + Assert.Throws(() => + new ProviderDefinedType(null!, new Dictionary())); + } + + [Fact] + public void TestConstructorThrowsOnEmptyName() + { + Assert.Throws(() => + new ProviderDefinedType("", new Dictionary())); + } + + [Fact] + public void TestEquality() + { + var a = new ProviderDefinedType("com.example.T", new Dictionary { { "k", 1 } }); + var b = new ProviderDefinedType("com.example.T", new Dictionary { { "k", 1 } }); + Assert.Equal(a, b); + Assert.Equal(a.GetHashCode(), b.GetHashCode()); + } + + [Fact] + public void TestInequality() + { + var a = new ProviderDefinedType("com.example.A", new Dictionary()); + var b = new ProviderDefinedType("com.example.B", new Dictionary()); + Assert.NotEqual(a, b); + } + + [Fact] + public void TestToString() + { + var pdt = new ProviderDefinedType("com.example.T", new Dictionary { { "x", 42 } }); + Assert.Contains("com.example.T", pdt.ToString()); + Assert.Contains("x=42", pdt.ToString()); + } + } +} diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/ProviderDefinedTypeRegistryTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/ProviderDefinedTypeRegistryTests.cs new file mode 100644 index 00000000000..a9aefc6ee22 --- /dev/null +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/ProviderDefinedTypeRegistryTests.cs @@ -0,0 +1,196 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Collections.Generic; +using System.Reflection; +using Gremlin.Net.Structure; +using Xunit; + +namespace Gremlin.Net.UnitTest.Structure +{ + public class ProviderDefinedTypeRegistryTests + { + [Fact] + public void ShouldHydrateToTypedObjectWhenAdapterRegistered() + { + var registry = new ProviderDefinedTypeRegistry(); + registry.Register(new PointAdapter()); + var pdt = new ProviderDefinedType("geo:Point", + new Dictionary { ["x"] = 1.0, ["y"] = 2.0 }); + + var result = registry.Hydrate(pdt); + + var point = Assert.IsType(result); + Assert.Equal(1.0, point.X); + Assert.Equal(2.0, point.Y); + } + + [Fact] + public void ShouldReturnRawPdtWhenNoAdapterRegistered() + { + var registry = new ProviderDefinedTypeRegistry(); + var pdt = new ProviderDefinedType("unknown:Type", + new Dictionary { ["a"] = "b" }); + + var result = registry.Hydrate(pdt); + + Assert.Same(pdt, result); + } + + [Fact] + public void ShouldReturnRawPdtWhenAdapterThrows() + { + var registry = new ProviderDefinedTypeRegistry(); + registry.Register(new ThrowingAdapter()); + var pdt = new ProviderDefinedType("bad:Type", + new Dictionary { ["x"] = "oops" }); + + var result = registry.Hydrate(pdt); + + Assert.Same(pdt, result); + } + + [Fact] + public void ShouldHydrateNestedPdt() + { + var registry = new ProviderDefinedTypeRegistry(); + registry.Register(new PointAdapter()); + registry.Register(new LineAdapter()); + var startPdt = new ProviderDefinedType("geo:Point", + new Dictionary { ["x"] = 0.0, ["y"] = 0.0 }); + var endPdt = new ProviderDefinedType("geo:Point", + new Dictionary { ["x"] = 3.0, ["y"] = 4.0 }); + var linePdt = new ProviderDefinedType("geo:Line", + new Dictionary { ["start"] = startPdt, ["end"] = endPdt }); + + var result = registry.Hydrate(linePdt); + + var line = Assert.IsType(result); + Assert.Equal(0.0, line.Start.X); + Assert.Equal(0.0, line.Start.Y); + Assert.Equal(3.0, line.End.X); + Assert.Equal(4.0, line.End.Y); + } + + [Fact] + public void ShouldHaveProviderDefinedAttributeWithNameProperty() + { + var attr = typeof(AnnotatedPoint).GetCustomAttribute(); + + Assert.NotNull(attr); + Assert.Equal("geo:Point", attr!.Name); + } + + [Fact] + public void BuildShouldReturnRegistryWithoutCrashing() + { + var registry = ProviderDefinedTypeRegistry.Build(); + + Assert.NotNull(registry); + } + + [Fact] + public void BuildShouldDiscoverAdapterFromAssembly() + { + var registry = ProviderDefinedTypeRegistry.Build(); + var pdt = new ProviderDefinedType("test:Discoverable", + new Dictionary { ["value"] = "hello" }); + + var result = registry.Hydrate(pdt); + + var obj = Assert.IsType(result); + Assert.Equal("hello", obj.Value); + } + + #region Test helpers + + private class Point + { + public double X { get; init; } + public double Y { get; init; } + } + + private class Line + { + public Point Start { get; init; } = null!; + public Point End { get; init; } = null!; + } + + [ProviderDefined(Name = "geo:Point")] + private class AnnotatedPoint { } + + private class PointAdapter : IProviderDefinedTypeAdapter + { + public string TypeName => "geo:Point"; + + public Point FromProperties(IReadOnlyDictionary properties) => + new() { X = (double)properties["x"]!, Y = (double)properties["y"]! }; + + public IReadOnlyDictionary ToProperties(Point obj) => + new Dictionary { ["x"] = obj.X, ["y"] = obj.Y }; + } + + private class LineAdapter : IProviderDefinedTypeAdapter + { + public string TypeName => "geo:Line"; + + public Line FromProperties(IReadOnlyDictionary properties) => + new() { Start = (Point)properties["start"]!, End = (Point)properties["end"]! }; + + public IReadOnlyDictionary ToProperties(Line obj) => + new Dictionary { ["start"] = obj.Start, ["end"] = obj.End }; + } + + private class ThrowingAdapter : IProviderDefinedTypeAdapter + { + public string TypeName => "bad:Type"; + + public object FromProperties(IReadOnlyDictionary properties) => + throw new InvalidOperationException("intentional failure"); + + public IReadOnlyDictionary ToProperties(object obj) => + throw new InvalidOperationException("intentional failure"); + } + + #endregion + } + + /// Test type discoverable by Build() assembly scanning. + public class DiscoverableType + { + public string Value { get; init; } = ""; + } + + /// Test adapter discoverable by Build() assembly scanning. + public class DiscoverableTypeAdapter : IProviderDefinedTypeAdapter + { + public string TypeName => "test:Discoverable"; + + public DiscoverableType FromProperties(IReadOnlyDictionary properties) => + new() { Value = (string)properties["value"]! }; + + public IReadOnlyDictionary ToProperties(DiscoverableType obj) => + new Dictionary { ["value"] = obj.Value }; + } +} diff --git a/gremlin-driver/src/main/java/examples/Connections.java b/gremlin-driver/src/main/java/examples/Connections.java index 42fe78da319..16ce3d0652b 100644 --- a/gremlin-driver/src/main/java/examples/Connections.java +++ b/gremlin-driver/src/main/java/examples/Connections.java @@ -24,8 +24,6 @@ Licensed to the Apache Software Foundation (ASF) under one import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection; import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; import org.apache.tinkerpop.gremlin.structure.Graph; -import org.apache.tinkerpop.gremlin.structure.io.AbstractIoRegistry; -import org.apache.tinkerpop.gremlin.structure.io.IoRegistry; import org.apache.tinkerpop.gremlin.structure.io.binary.TypeSerializerRegistry; import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerGraph; import org.apache.tinkerpop.gremlin.util.MessageSerializer; @@ -93,8 +91,7 @@ private static void withCluster() throws Exception { // Connecting and specifying a serializer private static void withSerializer() throws Exception { - IoRegistry registry = new FakeIoRegistry(); // an IoRegistry instance exposed by a specific graph provider - TypeSerializerRegistry typeSerializerRegistry = TypeSerializerRegistry.build().addRegistry(registry).create(); + TypeSerializerRegistry typeSerializerRegistry = TypeSerializerRegistry.build().create(); MessageSerializer serializer = new GraphBinaryMessageSerializerV4(typeSerializerRegistry); Cluster cluster = Cluster.build(SERVER_HOST). port(SERVER_PORT). @@ -110,6 +107,4 @@ private static void withSerializer() throws Exception { cluster.close(); g.close(); } - - public static class FakeIoRegistry extends AbstractIoRegistry {} } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/remote/DriverRemoteConnection.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/remote/DriverRemoteConnection.java index baae303ce91..3643aa9415a 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/remote/DriverRemoteConnection.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/remote/DriverRemoteConnection.java @@ -27,6 +27,7 @@ import org.apache.tinkerpop.gremlin.process.traversal.GremlinLang; import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; import org.apache.tinkerpop.gremlin.structure.Transaction; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils; import java.util.Optional; @@ -54,6 +55,7 @@ public class DriverRemoteConnection implements RemoteConnection { private transient Optional conf = Optional.empty(); private final boolean attachElements; + private ProviderDefinedTypeRegistry pdtRegistry; public DriverRemoteConnection(final Configuration conf) { final boolean hasClusterConf = IteratorUtils.anyMatch(conf.getKeys(), k -> k.startsWith("clusterConfiguration")); @@ -214,6 +216,18 @@ public static DriverRemoteConnection using(final Configuration conf) { } } + @Override + public ProviderDefinedTypeRegistry getPdtRegistry() { + return pdtRegistry; + } + + /** + * Sets the {@link ProviderDefinedTypeRegistry} for registry-based dehydration in the gremlin-lang translator. + */ + public void setPdtRegistry(final ProviderDefinedTypeRegistry pdtRegistry) { + this.pdtRegistry = pdtRegistry; + } + @Override public CompletableFuture> submitAsync(final GremlinLang gremlinLang) throws RemoteConnectionException { if (gremlinLang.containsUnsupportedTypes()) { diff --git a/gremlin-examples/gremlin-java/Connections.java b/gremlin-examples/gremlin-java/Connections.java index 8dd236a7f2d..966fe6d81be 100644 --- a/gremlin-examples/gremlin-java/Connections.java +++ b/gremlin-examples/gremlin-java/Connections.java @@ -25,8 +25,6 @@ Licensed to the Apache Software Foundation (ASF) under one import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection; import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; import org.apache.tinkerpop.gremlin.structure.Graph; -import org.apache.tinkerpop.gremlin.structure.io.AbstractIoRegistry; -import org.apache.tinkerpop.gremlin.structure.io.IoRegistry; import org.apache.tinkerpop.gremlin.structure.io.binary.TypeSerializerRegistry; import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerGraph; import org.apache.tinkerpop.gremlin.util.MessageSerializer; @@ -96,8 +94,7 @@ private static void withCluster() throws Exception { // Connecting and specifying a serializer private static void withSerializer() throws Exception { - IoRegistry registry = new FakeIoRegistry(); // an IoRegistry instance exposed by a specific graph provider - TypeSerializerRegistry typeSerializerRegistry = TypeSerializerRegistry.build().addRegistry(registry).create(); + TypeSerializerRegistry typeSerializerRegistry = TypeSerializerRegistry.build().create(); MessageSerializer serializer = new GraphBinaryMessageSerializerV1(typeSerializerRegistry); Cluster cluster = Cluster.build(SERVER_HOST). port(SERVER_PORT). @@ -113,6 +110,4 @@ private static void withSerializer() throws Exception { cluster.close(); g.close(); } - - public static class FakeIoRegistry extends AbstractIoRegistry {} } diff --git a/gremlin-go/driver/client.go b/gremlin-go/driver/client.go index 5bee34ed8e3..4707cfb00b8 100644 --- a/gremlin-go/driver/client.go +++ b/gremlin-go/driver/client.go @@ -61,6 +61,9 @@ type ClientSettings struct { EnableUserAgentOnConnect bool + // PDTRegistry enables automatic hydration of ProviderDefinedType values during deserialization. + PDTRegistry *PDTRegistry + // RequestInterceptors are functions that modify HTTP requests before sending. RequestInterceptors []RequestInterceptor } @@ -105,6 +108,7 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C keepAliveInterval: settings.KeepAliveInterval, enableCompression: settings.EnableCompression, enableUserAgentOnConnect: settings.EnableUserAgentOnConnect, + pdtRegistry: settings.PDTRegistry, } logHandler := newLogHandler(settings.Logger, settings.LogVerbosity, settings.Language) diff --git a/gremlin-go/driver/client_test.go b/gremlin-go/driver/client_test.go index 9b201efb3f9..c234a99918a 100644 --- a/gremlin-go/driver/client_test.go +++ b/gremlin-go/driver/client_test.go @@ -24,6 +24,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestClient(t *testing.T) { @@ -293,3 +294,97 @@ func AssertMarkoVertexWithoutProperties(t *testing.T, result *Result) { assert.True(t, ok) assert.Equal(t, 0, len(properties)) } + +func TestProviderDefinedTypeIntegration(t *testing.T) { + testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) + testNoAuthEnable := getEnvOrDefaultBool("RUN_INTEGRATION_TESTS", true) + + t.Run("simple Point PDT round-trip", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { + settings.TlsConfig = &tls.Config{} + settings.TraversalSource = testServerModernGraphAlias + }) + require.NoError(t, err) + defer client.Close() + + rs, err := client.Submit("g.inject(PDT(\"Point\", [\"x\":1, \"y\":2]))") + require.NoError(t, err) + + result, ok, err := rs.One() + require.NoError(t, err) + require.True(t, ok) + + pdt, ok := result.Data.(*ProviderDefinedType) + require.True(t, ok, "expected *ProviderDefinedType, got %T", result.Data) + assert.Equal(t, "Point", pdt.Name) + assert.Equal(t, int32(1), pdt.Properties["x"]) + assert.Equal(t, int32(2), pdt.Properties["y"]) + }) + + t.Run("nested PDT (Person with Address)", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { + settings.TlsConfig = &tls.Config{} + settings.TraversalSource = testServerModernGraphAlias + }) + require.NoError(t, err) + defer client.Close() + + rs, err := client.Submit( + "g.inject(PDT(\"Person\", [\"name\":\"Alice\", \"age\":30, " + + "\"address\":PDT(\"Address\", [\"street\":\"123 Main St\", \"city\":\"Springfield\", \"zip\":\"12345\"])]))") + require.NoError(t, err) + + result, ok, err := rs.One() + require.NoError(t, err) + require.True(t, ok) + + pdt, ok := result.Data.(*ProviderDefinedType) + require.True(t, ok, "expected *ProviderDefinedType, got %T", result.Data) + assert.Equal(t, "Person", pdt.Name) + assert.Equal(t, "Alice", pdt.Properties["name"]) + assert.Equal(t, int32(30), pdt.Properties["age"]) + + address, ok := pdt.Properties["address"].(*ProviderDefinedType) + require.True(t, ok, "expected nested *ProviderDefinedType for address") + assert.Equal(t, "Address", address.Name) + assert.Equal(t, "123 Main St", address.Properties["street"]) + assert.Equal(t, "Springfield", address.Properties["city"]) + assert.Equal(t, "12345", address.Properties["zip"]) + }) + + t.Run("PDT in collection", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { + settings.TlsConfig = &tls.Config{} + settings.TraversalSource = testServerModernGraphAlias + }) + require.NoError(t, err) + defer client.Close() + + rs, err := client.Submit( + "g.inject([PDT(\"Point\", [\"x\":1, \"y\":2]), PDT(\"Point\", [\"x\":3, \"y\":4])])") + require.NoError(t, err) + + result, ok, err := rs.One() + require.NoError(t, err) + require.True(t, ok) + + list, ok := result.Data.([]interface{}) + require.True(t, ok, "expected []interface{}, got %T", result.Data) + require.Len(t, list, 2) + + p1, ok := list[0].(*ProviderDefinedType) + require.True(t, ok) + assert.Equal(t, "Point", p1.Name) + assert.Equal(t, int32(1), p1.Properties["x"]) + assert.Equal(t, int32(2), p1.Properties["y"]) + + p2, ok := list[1].(*ProviderDefinedType) + require.True(t, ok) + assert.Equal(t, "Point", p2.Name) + assert.Equal(t, int32(3), p2.Properties["x"]) + assert.Equal(t, int32(4), p2.Properties["y"]) + }) +} \ No newline at end of file diff --git a/gremlin-go/driver/connection.go b/gremlin-go/driver/connection.go index 882161e36c7..66947977ecb 100644 --- a/gremlin-go/driver/connection.go +++ b/gremlin-go/driver/connection.go @@ -43,6 +43,7 @@ type connectionSettings struct { keepAliveInterval time.Duration enableCompression bool enableUserAgentOnConnect bool + pdtRegistry *PDTRegistry } // connection handles HTTP request/response for Gremlin queries. @@ -285,7 +286,12 @@ func (c *connection) getReader(resp *http.Response) (io.Reader, io.Closer, error } func (c *connection) streamToResultSet(reader io.Reader, rs ResultSet) { - d := NewGraphBinaryDeserializer(reader) + var d *GraphBinaryDeserializer + if c.connSettings.pdtRegistry != nil { + d = NewGraphBinaryDeserializerWithRegistry(reader, c.connSettings.pdtRegistry) + } else { + d = NewGraphBinaryDeserializer(reader) + } if err := d.ReadHeader(); err != nil { if err != io.EOF { c.logHandler.logf(Error, failedToReceiveResponse, err.Error()) diff --git a/gremlin-go/driver/driverRemoteConnection.go b/gremlin-go/driver/driverRemoteConnection.go index 7239b8d5c46..a5a47a8c517 100644 --- a/gremlin-go/driver/driverRemoteConnection.go +++ b/gremlin-go/driver/driverRemoteConnection.go @@ -59,6 +59,9 @@ type DriverRemoteConnectionSettings struct { // RequestInterceptors are functions that modify HTTP requests before sending. RequestInterceptors []RequestInterceptor + + // PDTRegistry enables registry-based dehydration in the gremlin-lang translator. + PDTRegistry *PDTRegistry } // DriverRemoteConnection is a remote connection. @@ -103,6 +106,7 @@ func NewDriverRemoteConnection( keepAliveInterval: settings.KeepAliveInterval, enableCompression: settings.EnableCompression, enableUserAgentOnConnect: settings.EnableUserAgentOnConnect, + pdtRegistry: settings.PDTRegistry, } logHandler := newLogHandler(settings.Logger, settings.LogVerbosity, settings.Language) diff --git a/gremlin-go/driver/error_codes.go b/gremlin-go/driver/error_codes.go index 6c0dea424c0..ca03def2523 100644 --- a/gremlin-go/driver/error_codes.go +++ b/gremlin-go/driver/error_codes.go @@ -54,7 +54,6 @@ const ( err0406EnumReaderInvalidTypeError errorCode = "E0406_GRAPH_BINARY_ENUMREADER_INVALID_TYPE_ERROR" err0407GetSerializerToWriteUnknownTypeError errorCode = "E0407_GRAPH_BINARY_GETSERIALIZERTOWRITE_UNKNOWN_TYPE_ERROR" err0408GetSerializerToReadUnknownTypeError errorCode = "E0408_GRAPH_BINARY_GETSERIALIZERTOREAD_UNKNOWN_TYPE_ERROR" - err0409GetSerializerToReadUnknownCustomTypeError errorCode = "E0409_GRAPH_BINARY_GETSERIALIZERTOREAD_UNKNOWN_CUSTOM_TYPE_ERROR" // response handling errors err0501ResponseResultSetNotCreatedError errorCode = "E0501_RESPONSE_NO_RESULTSET_ON_DATA_RECEIVE" diff --git a/gremlin-go/driver/graphBinaryDeserializer.go b/gremlin-go/driver/graphBinaryDeserializer.go index 5939dbbda07..8b20080cb04 100644 --- a/gremlin-go/driver/graphBinaryDeserializer.go +++ b/gremlin-go/driver/graphBinaryDeserializer.go @@ -57,10 +57,11 @@ import ( // The bufio.Reader wrapper provides efficient buffering without affecting the // streaming semantics - it simply reduces the number of underlying read syscalls. type GraphBinaryDeserializer struct { - r *bufio.Reader - buf [8]byte - err error // sticky error - bulked bool // whether the response stream uses bulked encoding + r *bufio.Reader + buf [8]byte + err error // sticky error + bulked bool // whether the response stream uses bulked encoding + pdtRegistry *PDTRegistry // optional: auto-hydrates ProviderDefinedType results } // GraphBinary flag for bulked list/set @@ -72,6 +73,12 @@ func NewGraphBinaryDeserializer(r io.Reader) *GraphBinaryDeserializer { return &GraphBinaryDeserializer{r: bufio.NewReaderSize(r, 8192)} } +// NewGraphBinaryDeserializerWithRegistry creates a new GraphBinaryDeserializer with a PDTRegistry +// for automatic hydration of ProviderDefinedType values. +func NewGraphBinaryDeserializerWithRegistry(r io.Reader, registry *PDTRegistry) *GraphBinaryDeserializer { + return &GraphBinaryDeserializer{r: bufio.NewReaderSize(r, 8192), pdtRegistry: registry} +} + func (d *GraphBinaryDeserializer) readByte() (byte, error) { if d.err != nil { return 0, d.err @@ -267,6 +274,8 @@ func (d *GraphBinaryDeserializer) readValue(dt dataType, flag byte) (interface{} return d.readByteBuffer() case tType, directionType, mergeType, gTypeType: return d.readEnum(dt) + case compositePDTType: + return d.readCompositePDT() default: return nil, newError(err0408GetSerializerToReadUnknownTypeError, dt) } @@ -613,6 +622,40 @@ func (d *GraphBinaryDeserializer) readEnum(dt dataType) (interface{}, error) { } } +func (d *GraphBinaryDeserializer) readCompositePDT() (interface{}, error) { + nameObj, err := d.ReadFullyQualified() + if err != nil { + return nil, err + } + name, ok := nameObj.(string) + if !ok || name == "" { + return nil, fmt.Errorf("ProviderDefinedType name must be a non-empty string") + } + propsObj, err := d.ReadFullyQualified() + if err != nil { + return nil, err + } + var props map[string]interface{} + if propsObj != nil { + raw, ok := propsObj.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("ProviderDefinedType properties must be a map") + } + props = make(map[string]interface{}, len(raw)) + for k, v := range raw { + props[fmt.Sprint(k)] = v + } + } + pdt := &ProviderDefinedType{Name: name, Properties: props} + if d.pdtRegistry != nil { + hydrated := d.pdtRegistry.Hydrate(pdt) + if hydrated != pdt { + return hydrated, nil + } + } + return pdt, nil +} + // ReadStatus reads the response status after the EndOfStream marker. // Returns the status code, message, exception string, and any error encountered. // This should be called after ReadFullyQualified() returns an EndOfStream marker. diff --git a/gremlin-go/driver/graphBinarySerializer.go b/gremlin-go/driver/graphBinarySerializer.go index c473bb880a4..2a353575c7d 100644 --- a/gremlin-go/driver/graphBinarySerializer.go +++ b/gremlin-go/driver/graphBinarySerializer.go @@ -37,7 +37,6 @@ type dataType uint8 // dataType defined as constants. const ( - customType dataType = 0x00 intType dataType = 0x01 longType dataType = 0x02 stringType dataType = 0x03 @@ -64,6 +63,7 @@ const ( mergeType dataType = 0x2e gTypeType dataType = 0x30 durationType dataType = 0x81 + compositePDTType dataType = 0xf0 markerType dataType = 0xfd nullType dataType = 0xFE ) @@ -478,6 +478,8 @@ func (serializer *graphBinaryTypeSerializer) getType(val interface{}) (dataType, return bigDecimalType, nil case *ByteBuffer, ByteBuffer: return byteBuffer, nil + case *ProviderDefinedType: + return compositePDTType, nil default: switch reflect.TypeOf(val).Kind() { case reflect.Map: diff --git a/gremlin-go/driver/graphBinarySerializer_test.go b/gremlin-go/driver/graphBinarySerializer_test.go index 4403ea70c4c..ea43a45e6d9 100644 --- a/gremlin-go/driver/graphBinarySerializer_test.go +++ b/gremlin-go/driver/graphBinarySerializer_test.go @@ -535,3 +535,104 @@ func TestWriterErrorPropagation(t *testing.T) { assert.Equal(t, 22, w.written) }) } + +func TestProviderDefinedTypeSerialization(t *testing.T) { + serializer := graphBinaryTypeSerializer{newLogHandler(&defaultLogger{}, Error, language.English)} + + t.Run("round-trip simple PDT", func(t *testing.T) { + source := &ProviderDefinedType{ + Name: "com.example.MyType", + Properties: map[string]interface{}{"key": "value", "num": int32(42)}, + } + var buf bytes.Buffer + err := serializer.write(source, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializer(bytes.NewReader(buf.Bytes())) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + pdt, ok := result.(*ProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, source.Name, pdt.Name) + assert.Equal(t, source.Properties["key"], pdt.Properties["key"]) + assert.Equal(t, source.Properties["num"], pdt.Properties["num"]) + }) + + t.Run("round-trip nested PDT", func(t *testing.T) { + inner := &ProviderDefinedType{ + Name: "com.example.Inner", + Properties: map[string]interface{}{"x": int32(1)}, + } + outer := &ProviderDefinedType{ + Name: "com.example.Outer", + Properties: map[string]interface{}{"child": inner}, + } + var buf bytes.Buffer + err := serializer.write(outer, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializer(bytes.NewReader(buf.Bytes())) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + pdt, ok := result.(*ProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, "com.example.Outer", pdt.Name) + child, ok := pdt.Properties["child"].(*ProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, "com.example.Inner", child.Name) + assert.Equal(t, int32(1), child.Properties["x"]) + }) + + t.Run("empty name produces error", func(t *testing.T) { + data := []byte{ + 0xf0, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, + } + d := NewGraphBinaryDeserializer(bytes.NewReader(data)) + _, err := d.ReadFullyQualified() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "non-empty") + }) + + t.Run("auto-hydrate with registry", func(t *testing.T) { + registry := NewPDTRegistry() + registry.RegisterFuncs("com.example.MyType", + func(props map[string]interface{}) (interface{}, error) { + return map[string]interface{}{"hydrated": true, "key": props["key"]}, nil + }, nil) + + source := &ProviderDefinedType{ + Name: "com.example.MyType", + Properties: map[string]interface{}{"key": "value"}, + } + var buf bytes.Buffer + err := serializer.write(source, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializerWithRegistry(bytes.NewReader(buf.Bytes()), registry) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + m, ok := result.(map[string]interface{}) + assert.True(t, ok) + assert.Equal(t, true, m["hydrated"]) + assert.Equal(t, "value", m["key"]) + }) + + t.Run("no hydration without registry", func(t *testing.T) { + source := &ProviderDefinedType{ + Name: "com.example.MyType", + Properties: map[string]interface{}{"key": "value"}, + } + var buf bytes.Buffer + err := serializer.write(source, &buf) + assert.Nil(t, err) + + d := NewGraphBinaryDeserializer(bytes.NewReader(buf.Bytes())) + result, err := d.ReadFullyQualified() + assert.Nil(t, err) + pdt, ok := result.(*ProviderDefinedType) + assert.True(t, ok) + assert.Equal(t, "com.example.MyType", pdt.Name) + }) +} diff --git a/gremlin-go/driver/graphTraversalSource.go b/gremlin-go/driver/graphTraversalSource.go index c550bd9556a..6ec85c8f7d4 100644 --- a/gremlin-go/driver/graphTraversalSource.go +++ b/gremlin-go/driver/graphTraversalSource.go @@ -41,6 +41,9 @@ func NewGraphTraversalSource(graph *Graph, remoteConnection *DriverRemoteConnect traversalStrategies ...TraversalStrategy) *GraphTraversalSource { // TODO: revisit when updating strategies gl := NewGremlinLang(nil) + if remoteConnection != nil && remoteConnection.settings != nil && remoteConnection.settings.PDTRegistry != nil { + gl.pdtRegistry = remoteConnection.settings.PDTRegistry + } return &GraphTraversalSource{graph: graph, gremlinLang: gl, remoteConnection: remoteConnection} } @@ -137,6 +140,9 @@ func (gts *GraphTraversalSource) With(key interface{}, value interface{}) *Graph // WithRemote adds a remote to be used throughout the life of a spawned Traversal. func (gts *GraphTraversalSource) WithRemote(remoteConnection *DriverRemoteConnection) *GraphTraversalSource { gts.remoteConnection = remoteConnection + if remoteConnection != nil && remoteConnection.settings != nil && remoteConnection.settings.PDTRegistry != nil { + gts.gremlinLang.pdtRegistry = remoteConnection.settings.PDTRegistry + } if gts.graphTraversal != nil { gts.graphTraversal.remote = remoteConnection } diff --git a/gremlin-go/driver/gremlinlang.go b/gremlin-go/driver/gremlinlang.go index 2625cb10e6b..59b900eab48 100644 --- a/gremlin-go/driver/gremlinlang.go +++ b/gremlin-go/driver/gremlinlang.go @@ -26,6 +26,7 @@ import ( "math" "math/big" "reflect" + "sort" "strconv" "strings" "time" @@ -38,6 +39,7 @@ type GremlinLang struct { gremlin []string parameters map[string]interface{} optionsStrategies []*traversalStrategy + pdtRegistry *PDTRegistry } // NewGremlinLang creates a new GremlinLang to be used in traversals. @@ -45,6 +47,7 @@ func NewGremlinLang(gl *GremlinLang) *GremlinLang { gremlin := make([]string, 0) parameters := make(map[string]interface{}) optionsStrategies := make([]*traversalStrategy, 0) + var registry *PDTRegistry if gl != nil { gremlin = make([]string, len(gl.gremlin)) copy(gremlin, gl.gremlin) @@ -56,12 +59,14 @@ func NewGremlinLang(gl *GremlinLang) *GremlinLang { optionsStrategies = make([]*traversalStrategy, len(gl.optionsStrategies)) copy(optionsStrategies, gl.optionsStrategies) + registry = gl.pdtRegistry } return &GremlinLang{ gremlin: gremlin, parameters: parameters, optionsStrategies: optionsStrategies, + pdtRegistry: registry, } } @@ -200,6 +205,16 @@ func (gl *GremlinLang) argAsString(arg interface{}) (string, error) { case dt: name := reflect.ValueOf(v).Type().Name() return fmt.Sprintf("%s.%s", strings.ToUpper(name), v), nil + case *ProviderDefinedType: + props := v.Properties + if props == nil { + props = map[string]interface{}{} + } + mapStr, err := gl.translateMap(props) + if err != nil { + return "", err + } + return fmt.Sprintf("PDT(\"%s\",%s)", escapeString(v.Name), mapStr), nil case *Vertex: return gl.argAsString(v.Id) case textP: @@ -284,6 +299,17 @@ func (gl *GremlinLang) argAsString(arg interface{}) (string, error) { case []byte: return fmt.Sprintf("Binary(\"%s\")", base64.StdEncoding.EncodeToString(v)), nil default: + // Registry-based dehydration + if gl.pdtRegistry != nil { + adapter := gl.pdtRegistry.GetAdapterByType(reflect.TypeOf(arg)) + if adapter != nil && adapter.ToProperties != nil { + props, err := adapter.ToProperties(arg) + if err == nil { + pdt := &ProviderDefinedType{Name: adapter.TypeName, Properties: props} + return gl.argAsString(pdt) + } + } + } switch reflect.TypeOf(arg).Kind() { case reflect.Map: return gl.translateMap(arg) @@ -308,14 +334,19 @@ func (gl *GremlinLang) translateMap(arg interface{}) (string, error) { if size == 0 { sb.WriteString(":") } else { - iter := reflect.ValueOf(arg).MapRange() - for iter.Next() { - k := iter.Key().Interface() + mapVal := reflect.ValueOf(arg) + keys := mapVal.MapKeys() + // Sort keys for deterministic output (not semantic ordering) + sort.Slice(keys, func(i, j int) bool { + return fmt.Sprintf("%v", keys[i].Interface()) < fmt.Sprintf("%v", keys[j].Interface()) + }) + for idx, key := range keys { + k := key.Interface() kString, err := gl.argAsString(k) if err != nil { return "", err } - v := iter.Value().Interface() + v := mapVal.MapIndex(key).Interface() vString, err := gl.argAsString(v) if err != nil { return "", err @@ -323,8 +354,7 @@ func (gl *GremlinLang) translateMap(arg interface{}) (string, error) { sb.WriteString(kString) sb.WriteByte(':') sb.WriteString(vString) - size-- - if size > 0 { + if idx < len(keys)-1 { sb.WriteString(",") } } diff --git a/gremlin-go/driver/gremlinlang_test.go b/gremlin-go/driver/gremlinlang_test.go index 719bd48f603..c857fb5a91b 100644 --- a/gremlin-go/driver/gremlinlang_test.go +++ b/gremlin-go/driver/gremlinlang_test.go @@ -858,3 +858,66 @@ func Test_ConvertParametersToString(t *testing.T) { new(RequestOptionsBuilder).SetBindings(map[string]interface{}{"x": struct{}{}}).Create() }) } + +func Test_PDT_GremlinLang(t *testing.T) { + t.Run("basic PDT", func(t *testing.T) { + g := NewGraphTraversalSource(nil, nil) + pdt := &ProviderDefinedType{Name: "MyType", Properties: map[string]interface{}{"x": int32(1), "y": "hello"}} + gremlin := g.Inject(pdt).GremlinLang.GetGremlin() + expected := `g.inject(PDT("MyType",["x":1,"y":"hello"]))` + if gremlin != expected { + t.Errorf("got %v, expected %v", gremlin, expected) + } + }) + + t.Run("empty PDT", func(t *testing.T) { + g := NewGraphTraversalSource(nil, nil) + pdt := &ProviderDefinedType{Name: "Empty", Properties: map[string]interface{}{}} + gremlin := g.Inject(pdt).GremlinLang.GetGremlin() + expected := `g.inject(PDT("Empty",[:]))` + if gremlin != expected { + t.Errorf("got %v, expected %v", gremlin, expected) + } + }) + + t.Run("PDT with special characters in name", func(t *testing.T) { + g := NewGraphTraversalSource(nil, nil) + pdt := &ProviderDefinedType{Name: `say"hello"`, Properties: map[string]interface{}{"v": int32(1)}} + gremlin := g.Inject(pdt).GremlinLang.GetGremlin() + expected := `g.inject(PDT("say\"hello\"",["v":1]))` + if gremlin != expected { + t.Errorf("got %v, expected %v", gremlin, expected) + } + }) + + t.Run("PDT with backslash in name", func(t *testing.T) { + g := NewGraphTraversalSource(nil, nil) + pdt := &ProviderDefinedType{Name: `back\slash`, Properties: map[string]interface{}{"v": int32(1)}} + gremlin := g.Inject(pdt).GremlinLang.GetGremlin() + expected := `g.inject(PDT("back\\slash",["v":1]))` + if gremlin != expected { + t.Errorf("got %v, expected %v", gremlin, expected) + } + }) + + t.Run("nested PDT", func(t *testing.T) { + g := NewGraphTraversalSource(nil, nil) + inner := &ProviderDefinedType{Name: "Inner", Properties: map[string]interface{}{"v": int32(1)}} + outer := &ProviderDefinedType{Name: "Outer", Properties: map[string]interface{}{"inner": inner}} + gremlin := g.Inject(outer).GremlinLang.GetGremlin() + expected := `g.inject(PDT("Outer",["inner":PDT("Inner",["v":1])]))` + if gremlin != expected { + t.Errorf("got %v, expected %v", gremlin, expected) + } + }) + + t.Run("PDT map keys sorted", func(t *testing.T) { + g := NewGraphTraversalSource(nil, nil) + pdt := &ProviderDefinedType{Name: "T", Properties: map[string]interface{}{"z": int32(3), "a": int32(1), "m": int32(2)}} + gremlin := g.Inject(pdt).GremlinLang.GetGremlin() + expected := `g.inject(PDT("T",["a":1,"m":2,"z":3]))` + if gremlin != expected { + t.Errorf("got %v, expected %v", gremlin, expected) + } + }) +} diff --git a/gremlin-go/driver/pdtRegistry.go b/gremlin-go/driver/pdtRegistry.go new file mode 100644 index 00000000000..1d14feec6cd --- /dev/null +++ b/gremlin-go/driver/pdtRegistry.go @@ -0,0 +1,104 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import "reflect" + +// PDTAdapter defines how to hydrate/dehydrate a provider-defined type. +type PDTAdapter struct { + TypeName string + FromProperties func(map[string]interface{}) (interface{}, error) + ToProperties func(interface{}) (map[string]interface{}, error) +} + +// PDTRegistry maps type names to their hydration adapters. +type PDTRegistry struct { + adaptersByName map[string]*PDTAdapter + adaptersByType map[reflect.Type]*PDTAdapter +} + +// NewPDTRegistry creates an empty PDTRegistry. +func NewPDTRegistry() *PDTRegistry { + return &PDTRegistry{adaptersByName: make(map[string]*PDTAdapter), adaptersByType: make(map[reflect.Type]*PDTAdapter)} +} + +// RegisterFuncs registers hydration/dehydration functions for a type name. +func (r *PDTRegistry) RegisterFuncs(typeName string, fromProps func(map[string]interface{}) (interface{}, error), toProps func(interface{}) (map[string]interface{}, error)) { + adapter := &PDTAdapter{TypeName: typeName, FromProperties: fromProps, ToProperties: toProps} + r.adaptersByName[typeName] = adapter +} + +// RegisterFuncsWithType registers hydration/dehydration functions for a type name and associates a Go type for dehydration lookup. +func (r *PDTRegistry) RegisterFuncsWithType(typeName string, targetType reflect.Type, fromProps func(map[string]interface{}) (interface{}, error), toProps func(interface{}) (map[string]interface{}, error)) { + adapter := &PDTAdapter{TypeName: typeName, FromProperties: fromProps, ToProperties: toProps} + r.adaptersByName[typeName] = adapter + r.adaptersByType[targetType] = adapter +} + +// RegisterType registers a struct type for reflection-based hydration using "pdt" struct tags. +func (r *PDTRegistry) RegisterType(typeName string, targetType reflect.Type) { + r.adaptersByName[typeName] = &PDTAdapter{ + TypeName: typeName, + FromProperties: func(props map[string]interface{}) (interface{}, error) { + obj := reflect.New(targetType).Elem() + for i := 0; i < targetType.NumField(); i++ { + field := targetType.Field(i) + tag := field.Tag.Get("pdt") + if tag == "" { + tag = field.Name + } + if val, ok := props[tag]; ok && val != nil { + obj.Field(i).Set(reflect.ValueOf(val)) + } + } + return obj.Interface(), nil + }, + } +} + +// GetAdapterByType returns the adapter registered for the given Go type, or nil. +func (r *PDTRegistry) GetAdapterByType(t reflect.Type) *PDTAdapter { + return r.adaptersByType[t] +} + +// Hydrate converts a ProviderDefinedType into a domain object using the registered adapter. +// Returns the raw PDT if no adapter is found or if hydration fails. +func (r *PDTRegistry) Hydrate(pdt *ProviderDefinedType) interface{} { + if pdt == nil { + return nil + } + adapter, ok := r.adaptersByName[pdt.Name] + if !ok { + return pdt + } + hydratedProps := make(map[string]interface{}, len(pdt.Properties)) + for k, v := range pdt.Properties { + if nested, ok := v.(*ProviderDefinedType); ok { + hydratedProps[k] = r.Hydrate(nested) + } else { + hydratedProps[k] = v + } + } + result, err := adapter.FromProperties(hydratedProps) + if err != nil { + return pdt + } + return result +} diff --git a/gremlin-go/driver/pdtRegistry_test.go b/gremlin-go/driver/pdtRegistry_test.go new file mode 100644 index 00000000000..18a97566d8b --- /dev/null +++ b/gremlin-go/driver/pdtRegistry_test.go @@ -0,0 +1,92 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "errors" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPDTRegistryRegisterFuncsAndHydrate(t *testing.T) { + reg := NewPDTRegistry() + reg.RegisterFuncs("x:Point", func(props map[string]interface{}) (interface{}, error) { + return [2]int{props["x"].(int), props["y"].(int)}, nil + }, nil) + + pdt := &ProviderDefinedType{Name: "x:Point", Properties: map[string]interface{}{"x": 1, "y": 2}} + result := reg.Hydrate(pdt) + assert.Equal(t, [2]int{1, 2}, result) +} + +func TestPDTRegistryNoAdapterReturnsRawPDT(t *testing.T) { + reg := NewPDTRegistry() + pdt := &ProviderDefinedType{Name: "x:Unknown", Properties: map[string]interface{}{"a": "b"}} + result := reg.Hydrate(pdt) + assert.Equal(t, pdt, result) +} + +func TestPDTRegistryAdapterErrorReturnsRawPDT(t *testing.T) { + reg := NewPDTRegistry() + reg.RegisterFuncs("x:Bad", func(props map[string]interface{}) (interface{}, error) { + return nil, errors.New("fail") + }, nil) + + pdt := &ProviderDefinedType{Name: "x:Bad", Properties: map[string]interface{}{}} + result := reg.Hydrate(pdt) + assert.Equal(t, pdt, result) +} + +func TestPDTRegistryNestedHydration(t *testing.T) { + reg := NewPDTRegistry() + reg.RegisterFuncs("x:Inner", func(props map[string]interface{}) (interface{}, error) { + return props["val"].(string) + "!", nil + }, nil) + reg.RegisterFuncs("x:Outer", func(props map[string]interface{}) (interface{}, error) { + return "outer:" + props["child"].(string), nil + }, nil) + + inner := &ProviderDefinedType{Name: "x:Inner", Properties: map[string]interface{}{"val": "hi"}} + outer := &ProviderDefinedType{Name: "x:Outer", Properties: map[string]interface{}{"child": inner}} + result := reg.Hydrate(outer) + assert.Equal(t, "outer:hi!", result) +} + +type testPoint struct { + X int `pdt:"x"` + Y int `pdt:"y"` + L string // no tag, uses field name +} + +func TestPDTRegistryRegisterType(t *testing.T) { + reg := NewPDTRegistry() + reg.RegisterType("x:Point", reflect.TypeOf(testPoint{})) + + pdt := &ProviderDefinedType{Name: "x:Point", Properties: map[string]interface{}{"x": 3, "y": 4, "L": "label"}} + result := reg.Hydrate(pdt) + assert.Equal(t, testPoint{X: 3, Y: 4, L: "label"}, result) +} + +func TestPDTRegistryHydrateNil(t *testing.T) { + reg := NewPDTRegistry() + assert.Nil(t, reg.Hydrate(nil)) +} diff --git a/gremlin-go/driver/providerDefinedType.go b/gremlin-go/driver/providerDefinedType.go new file mode 100644 index 00000000000..c2d969f67e0 --- /dev/null +++ b/gremlin-go/driver/providerDefinedType.go @@ -0,0 +1,51 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "fmt" + "io" +) + +// ProviderDefinedType represents a provider-defined type (PDT) in GraphBinary serialization. +type ProviderDefinedType struct { + Name string + Properties map[string]interface{} +} + +func (p *ProviderDefinedType) String() string { + return fmt.Sprintf("pdt[%s]%v", p.Name, p.Properties) +} + +// pdtWriter serializes a ProviderDefinedType as a fully-qualified string (name) followed by a fully-qualified map (properties). +func pdtWriter(value interface{}, w io.Writer, typeSerializer *graphBinaryTypeSerializer) error { + pdt := value.(*ProviderDefinedType) + if err := typeSerializer.write(pdt.Name, w); err != nil { + return err + } + if pdt.Properties == nil { + return typeSerializer.write(map[interface{}]interface{}{}, w) + } + m := make(map[interface{}]interface{}, len(pdt.Properties)) + for k, v := range pdt.Properties { + m[k] = v + } + return typeSerializer.write(m, w) +} \ No newline at end of file diff --git a/gremlin-go/driver/providerDefinedType_test.go b/gremlin-go/driver/providerDefinedType_test.go new file mode 100644 index 00000000000..91963456d87 --- /dev/null +++ b/gremlin-go/driver/providerDefinedType_test.go @@ -0,0 +1,36 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProviderDefinedType(t *testing.T) { + t.Run("String method", func(t *testing.T) { + pdt := &ProviderDefinedType{ + Name: "com.example.Test", + Properties: map[string]interface{}{"a": int32(1)}, + } + assert.Contains(t, pdt.String(), "pdt[com.example.Test]") + }) +} diff --git a/gremlin-go/driver/serializer.go b/gremlin-go/driver/serializer.go index bd4a4262080..11fec813fc3 100644 --- a/gremlin-go/driver/serializer.go +++ b/gremlin-go/driver/serializer.go @@ -24,7 +24,6 @@ import ( "encoding/binary" "fmt" "io" - "sync" ) const graphBinaryMimeType = "application/vnd.graphbinary-v4.0" @@ -40,22 +39,12 @@ type GraphBinarySerializer struct { ser *graphBinaryTypeSerializer } -// CustomTypeReader user provided function to deserialize custom types -// Deprecated: Custom type deserialization is handled by GraphBinaryDeserializer -type CustomTypeReader func(data *[]byte, i *int) (interface{}, error) - type writer func(interface{}, io.Writer, *graphBinaryTypeSerializer) error var serializers map[dataType]writer -// customTypeReaderLock used to synchronize access to the customDeserializers map -// Deprecated: Custom type deserialization is handled by GraphBinaryDeserializer -var customTypeReaderLock = sync.RWMutex{} -var customDeserializers map[string]CustomTypeReader - func init() { initSerializers() - customDeserializers = map[string]CustomTypeReader{} } func newGraphBinarySerializer(handler *logHandler) *GraphBinarySerializer { @@ -231,21 +220,8 @@ func initSerializers() { setType: setWriter, byteBuffer: byteBufferWriter, markerType: markerWriter, + compositePDTType: pdtWriter, } } -// RegisterCustomTypeReader register a reader (deserializer) for a custom type -// Deprecated: Custom type deserialization should be handled by extending GraphBinaryDeserializer -func RegisterCustomTypeReader(customTypeName string, reader CustomTypeReader) { - customTypeReaderLock.Lock() - defer customTypeReaderLock.Unlock() - customDeserializers[customTypeName] = reader -} -// UnregisterCustomTypeReader unregister a reader (deserializer) for a custom type -// Deprecated: Custom type deserialization should be handled by extending GraphBinaryDeserializer -func UnregisterCustomTypeReader(customTypeName string) { - customTypeReaderLock.Lock() - defer customTypeReaderLock.Unlock() - delete(customDeserializers, customTypeName) -} diff --git a/gremlin-go/driver/traversal_test.go b/gremlin-go/driver/traversal_test.go index bc3fb339163..bc842f10083 100644 --- a/gremlin-go/driver/traversal_test.go +++ b/gremlin-go/driver/traversal_test.go @@ -26,6 +26,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestTraversal(t *testing.T) { @@ -855,3 +856,72 @@ func getCount(t *testing.T, g *GraphTraversalSource) int32 { assert.Nil(t, err) return val } + +func TestProviderDefinedTypeTraversalAPIIntegration(t *testing.T) { + testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) + testNoAuthEnable := getEnvOrDefaultBool("RUN_INTEGRATION_TESTS", true) + + t.Run("raw PDT round-trip via Traversal API", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + remote, err := NewDriverRemoteConnection(testNoAuthUrl, + func(settings *DriverRemoteConnectionSettings) { + settings.TlsConfig = &tls.Config{} + settings.TraversalSource = testServerModernGraphAlias + }) + require.NoError(t, err) + defer remote.Close() + + g := Traversal_().With(remote) + pdt := &ProviderDefinedType{Name: "TestPoint", Properties: map[string]interface{}{"x": int32(1), "y": int32(2)}} + + results, err := g.Inject(pdt).ToList() + require.NoError(t, err) + require.Len(t, results, 1) + + result, ok := results[0].GetInterface().(*ProviderDefinedType) + require.True(t, ok, "expected *ProviderDefinedType, got %T", results[0].GetInterface()) + assert.Equal(t, "TestPoint", result.Name) + assert.Equal(t, int32(1), result.Properties["x"]) + assert.Equal(t, int32(2), result.Properties["y"]) + }) + + t.Run("registry-based round-trip via typed struct", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + registry := NewPDTRegistry() + registry.RegisterFuncsWithType("RegPoint", reflect.TypeOf(regPoint{}), + func(props map[string]interface{}) (interface{}, error) { + return ®Point{X: props["x"].(int32), Y: props["y"].(int32)}, nil + }, + func(obj interface{}) (map[string]interface{}, error) { + p := obj.(regPoint) + return map[string]interface{}{"x": p.X, "y": p.Y}, nil + }) + + remote, err := NewDriverRemoteConnection(testNoAuthUrl, + func(settings *DriverRemoteConnectionSettings) { + settings.TlsConfig = &tls.Config{} + settings.TraversalSource = testServerModernGraphAlias + settings.PDTRegistry = registry + }) + require.NoError(t, err) + defer remote.Close() + + g := Traversal_().With(remote) + point := regPoint{X: 5, Y: 10} + + results, err := g.Inject(point).ToList() + require.NoError(t, err) + require.Len(t, results, 1) + + result, ok := results[0].GetInterface().(*regPoint) + require.True(t, ok, "expected *regPoint, got %T", results[0].GetInterface()) + assert.Equal(t, int32(5), result.X) + assert.Equal(t, int32(10), result.Y) + }) +} + +// regPoint is a plain struct used for registry-based tests. +type regPoint struct { + X int32 + Y int32 +} \ No newline at end of file diff --git a/gremlin-js/gremlin-javascript/lib/driver/connection.ts b/gremlin-js/gremlin-javascript/lib/driver/connection.ts index 12f7da04674..ec549577d8e 100644 --- a/gremlin-js/gremlin-javascript/lib/driver/connection.ts +++ b/gremlin-js/gremlin-javascript/lib/driver/connection.ts @@ -25,6 +25,7 @@ import { Buffer } from 'buffer'; import { EventEmitter } from 'eventemitter3'; import type { Agent } from 'node:http'; import ioc, { createPreciseReader } from '../structure/io/binary/GraphBinary.js'; +import GraphBinaryReader from '../structure/io/binary/internals/GraphBinaryReader.js'; import StreamReader from '../structure/io/binary/internals/StreamReader.js'; import * as utils from '../utils.js'; import ResultSet from './result-set.js'; @@ -32,7 +33,7 @@ import {RequestMessage} from "./request-message.js"; import ResponseError from './response-error.js'; import { Traverser } from '../process/traversal.js'; -const { graphBinaryReader, graphBinaryWriter } = ioc; +const { graphBinaryWriter } = ioc; const responseStatusCode = { success: 200, @@ -54,6 +55,7 @@ export type ConnectionOptions = { cert?: string | string[] | Buffer; pfx?: string | Buffer; preciseNumbers?: boolean; + pdtRegistry?: any; reader?: any; rejectUnauthorized?: boolean; traversalSource?: string; @@ -88,8 +90,11 @@ export default class Connection extends EventEmitter { ) { super(); - this._reader = options.reader || (options.preciseNumbers === true ? createPreciseReader() : graphBinaryReader); + this._reader = options.reader || (options.preciseNumbers === true ? createPreciseReader() : new GraphBinaryReader(ioc)); this._writer = 'writer' in options ? options.writer : graphBinaryWriter; + if (options.pdtRegistry) { + this._reader.pdtRegistry = options.pdtRegistry; + } this.traversalSource = options.traversalSource || 'g'; this._enableUserAgentOnConnect = options.enableUserAgentOnConnect !== false; diff --git a/gremlin-js/gremlin-javascript/lib/driver/remote-connection.ts b/gremlin-js/gremlin-javascript/lib/driver/remote-connection.ts index 44007e9bd4b..adf149779e8 100644 --- a/gremlin-js/gremlin-javascript/lib/driver/remote-connection.ts +++ b/gremlin-js/gremlin-javascript/lib/driver/remote-connection.ts @@ -39,7 +39,7 @@ export abstract class RemoteConnection { */ constructor( public url: string, - protected readonly options: RemoteConnectionOptions = {}, + public readonly options: RemoteConnectionOptions = {}, ) {} /** diff --git a/gremlin-js/gremlin-javascript/lib/index.ts b/gremlin-js/gremlin-javascript/lib/index.ts index f79c31e948b..ef0e7ce1334 100644 --- a/gremlin-js/gremlin-javascript/lib/index.ts +++ b/gremlin-js/gremlin-javascript/lib/index.ts @@ -25,6 +25,7 @@ import * as t from './process/traversal.js'; import * as gt from './process/graph-traversal.js'; import * as strategiesModule from './process/traversal-strategy.js'; import * as graph from './structure/graph.js'; +import { ProviderDefinedTypeRegistry } from './structure/ProviderDefinedTypeRegistry.js'; import * as rc from './driver/remote-connection.js'; import GremlinLang from './process/gremlin-lang.js'; import * as utils from './utils.js'; @@ -84,6 +85,8 @@ export const structure = { Graph: graph.Graph, Path: graph.Path, Property: graph.Property, + ProviderDefinedType: graph.ProviderDefinedType, + ProviderDefinedTypeRegistry, Vertex: graph.Vertex, VertexProperty: graph.VertexProperty, toLong: utils.toLong, diff --git a/gremlin-js/gremlin-javascript/lib/process/anonymous-traversal.ts b/gremlin-js/gremlin-javascript/lib/process/anonymous-traversal.ts index 3432735ac32..5920c48dc9d 100644 --- a/gremlin-js/gremlin-javascript/lib/process/anonymous-traversal.ts +++ b/gremlin-js/gremlin-javascript/lib/process/anonymous-traversal.ts @@ -59,10 +59,14 @@ export default class AnonymousTraversalSource { with_(connection: RemoteConnection) { const traversalStrategies = new TraversalStrategies(); traversalStrategies.addStrategy(new RemoteStrategy(connection)); + const gl = new GremlinLang(); + if (connection.options?.pdtRegistry) { + gl.pdtRegistry = connection.options.pdtRegistry; + } return new this.traversalSourceClass!( new Graph(), traversalStrategies, - new GremlinLang(), + gl, this.traversalSourceClass, this.traversalClass, ); diff --git a/gremlin-js/gremlin-javascript/lib/process/gremlin-lang.ts b/gremlin-js/gremlin-javascript/lib/process/gremlin-lang.ts index 2a07ad8dec4..f000f521011 100644 --- a/gremlin-js/gremlin-javascript/lib/process/gremlin-lang.ts +++ b/gremlin-js/gremlin-javascript/lib/process/gremlin-lang.ts @@ -20,19 +20,22 @@ import { P, TextP, EnumValue } from './traversal.js'; import { OptionsStrategy, TraversalStrategy } from './traversal-strategy.js'; import { Long, Int, Float, Double, Short, Byte, INT32_MIN, INT32_MAX } from '../utils.js'; -import { Vertex } from '../structure/graph.js'; +import { Vertex, ProviderDefinedType } from '../structure/graph.js'; +import { ProviderDefinedTypeRegistry } from '../structure/ProviderDefinedTypeRegistry.js'; import { Buffer } from 'buffer'; export default class GremlinLang { private gremlin: string = ''; private optionsStrategies: OptionsStrategy[] = []; private parameters: Map = new Map(); + pdtRegistry: ProviderDefinedTypeRegistry | null = null; constructor(toClone?: GremlinLang) { if (toClone) { this.gremlin = toClone.gremlin; this.optionsStrategies = [...toClone.optionsStrategies]; this.parameters = new Map(toClone.parameters); + this.pdtRegistry = toClone.pdtRegistry; } } @@ -128,6 +131,14 @@ export default class GremlinLang { if (typeof arg === 'function' && arg.prototype instanceof TraversalStrategy) { return arg.name; } + if (arg instanceof ProviderDefinedType) { + const props = arg.properties; + const keys = Object.keys(props); + const escapedName = JSON.stringify(arg.name).slice(1, -1); + if (keys.length === 0) return `PDT("${escapedName}",[:])`; + const entries = keys.map(k => `${this._argAsString(k)}:${this._argAsString(props[k])}`); + return `PDT("${escapedName}",[${entries.join(',')}])`; + } if (arg instanceof Vertex) { return this._argAsString(arg.id); } @@ -167,6 +178,14 @@ export default class GremlinLang { if (entries.length === 0) return '[:]'; return '[' + entries.map(([k, v]) => `${this._argAsString(k)}:${this._argAsString(v)}`).join(',') + ']'; } + // Registry-based dehydration + if (this.pdtRegistry && typeof arg === 'object' && arg.constructor) { + const entry = this.pdtRegistry.getAdapterByClass(arg.constructor); + if (entry) { + const props = entry.serialize(arg); + return this._argAsString(new ProviderDefinedType(entry.typeName, props)); + } + } throw new TypeError(`GremlinLang contains at least one type [${arg?.constructor?.name ?? typeof arg}] that cannot be represented as text.`); } diff --git a/gremlin-js/gremlin-javascript/lib/structure/ProviderDefinedTypeRegistry.ts b/gremlin-js/gremlin-javascript/lib/structure/ProviderDefinedTypeRegistry.ts new file mode 100644 index 00000000000..3a8c300aa7b --- /dev/null +++ b/gremlin-js/gremlin-javascript/lib/structure/ProviderDefinedTypeRegistry.ts @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { ProviderDefinedType } from './graph.js'; + +export interface PdtAdapter { + serialize: (obj: any) => Record; + deserialize: (properties: Record) => any; +} + +/** + * A standalone registry that allows users to register adapters for hydrating + * raw {@link ProviderDefinedType} instances into domain-specific objects. + */ +export class ProviderDefinedTypeRegistry { + private readonly _adapters: Map = new Map(); + private readonly _adaptersByClass: Map = new Map(); + + register(typeName: string, adapter: PdtAdapter, targetClass?: Function): void { + this._adapters.set(typeName, adapter); + if (targetClass) { + this._adaptersByClass.set(targetClass, { typeName, adapter }); + } + } + + hydrate(pdt: any): any { + if (!(pdt instanceof ProviderDefinedType)) return pdt; + const adapter = this._adapters.get(pdt.name); + if (!adapter) return pdt; + try { + const hydratedProps: Record = {}; + for (const [k, v] of Object.entries(pdt.properties)) { + hydratedProps[k] = v instanceof ProviderDefinedType ? this.hydrate(v) : v; + } + return adapter.deserialize(hydratedProps); + } catch (e: any) { + console.warn(`PDT hydration failed for '${pdt.name}': ${e.message}`); + return pdt; + } + } + + hasAdapter(typeName: string): boolean { + return this._adapters.has(typeName); + } + + getSerializer(typeName: string): ((obj: any) => Record) | null { + const adapter = this._adapters.get(typeName); + return adapter ? adapter.serialize : null; + } + + getAdapterByClass(cls: Function): { typeName: string; serialize: (obj: any) => Record } | null { + const entry = this._adaptersByClass.get(cls); + if (!entry) return null; + return { typeName: entry.typeName, serialize: entry.adapter.serialize }; + } +} diff --git a/gremlin-js/gremlin-javascript/lib/structure/graph.ts b/gremlin-js/gremlin-javascript/lib/structure/graph.ts index c8b86443601..9398af0aeae 100644 --- a/gremlin-js/gremlin-javascript/lib/structure/graph.ts +++ b/gremlin-js/gremlin-javascript/lib/structure/graph.ts @@ -185,6 +185,24 @@ function areEqual(obj1: any, obj2: any) { return false; } +/** + * Represents a composite Provider Defined Type (PDT). + */ +export class ProviderDefinedType { + readonly name: string; + readonly properties: Readonly>; + + constructor(name: string, properties?: Record) { + if (!name) throw new Error('ProviderDefinedType name cannot be null or empty'); + this.name = name; + this.properties = Object.freeze(properties || {}); + } + + toString() { + return `pdt[${this.name}]${JSON.stringify(this.properties)}`; + } +} + function summarize(value: any) { if (value === null || value === undefined) { return value; diff --git a/gremlin-js/gremlin-javascript/lib/structure/io/binary/GraphBinary.js b/gremlin-js/gremlin-javascript/lib/structure/io/binary/GraphBinary.js index 4bd402b5960..7dbc4df5b13 100644 --- a/gremlin-js/gremlin-javascript/lib/structure/io/binary/GraphBinary.js +++ b/gremlin-js/gremlin-javascript/lib/structure/io/binary/GraphBinary.js @@ -65,6 +65,7 @@ import MarkerSerializer from './internals/MarkerSerializer.js'; import UnspecifiedNullSerializer from './internals/UnspecifiedNullSerializer.js'; import EnumSerializer from './internals/EnumSerializer.js'; import StubSerializer from './internals/StubSerializer.js'; +import CompositePDTSerializer from './internals/CompositePDTSerializer.js'; import NumberSerializationStrategy from './internals/NumberSerializationStrategy.js'; import AnySerializer from './internals/AnySerializer.js'; import GraphBinaryReader from './internals/GraphBinaryReader.js'; @@ -103,13 +104,15 @@ function createIoc(anySerializerOptions) { ioc.markerSerializer = new MarkerSerializer(ioc); ioc.unspecifiedNullSerializer = new UnspecifiedNullSerializer(ioc); ioc.enumSerializer = new EnumSerializer(ioc); + ioc.compositePDTSerializer = new CompositePDTSerializer(ioc); // Register stub serializers for unimplemented v4 types new StubSerializer(ioc, ioc.DataType.TREE, 'Tree'); new StubSerializer(ioc, ioc.DataType.GRAPH, 'Graph'); - new StubSerializer(ioc, ioc.DataType.COMPOSITEPDT, 'CompositePDT'); new StubSerializer(ioc, ioc.DataType.PRIMITIVEPDT, 'PrimitivePDT'); + ioc.pdtRegistry = null; + ioc.numberSerializationStrategy = new NumberSerializationStrategy(ioc); ioc.anySerializer = new AnySerializer(ioc, anySerializerOptions); @@ -171,6 +174,7 @@ export const { markerSerializer, unspecifiedNullSerializer, enumSerializer, + compositePDTSerializer, numberSerializationStrategy, anySerializer, graphBinaryReader, diff --git a/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/AnySerializer.js b/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/AnySerializer.js index bfc20c00162..df78ef18243 100644 --- a/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/AnySerializer.js +++ b/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/AnySerializer.js @@ -43,6 +43,7 @@ export default class AnySerializer { ioc.enumSerializer, ioc.stringSerializer, ioc.binarySerializer, + ioc.compositePDTSerializer, ioc.mapSerializer, ]; } diff --git a/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/CompositePDTSerializer.js b/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/CompositePDTSerializer.js new file mode 100644 index 00000000000..e5a1d8f15ec --- /dev/null +++ b/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/CompositePDTSerializer.js @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Buffer } from 'buffer'; +import { ProviderDefinedType } from '../../../graph.js'; + +export default class CompositePDTSerializer { + constructor(ioc) { + this.ioc = ioc; + this.ioc.serializers[ioc.DataType.COMPOSITEPDT] = this; + } + + canBeUsedFor(value) { + return value instanceof ProviderDefinedType; + } + + serialize(item, fullyQualifiedFormat = true) { + if (item === undefined || item === null) { + if (fullyQualifiedFormat) { + return Buffer.from([this.ioc.DataType.COMPOSITEPDT, 0x01]); + } + const bufs = []; + bufs.push(this.ioc.stringSerializer.serialize('', false)); + bufs.push(this.ioc.mapSerializer.serialize({}, false)); + return Buffer.concat(bufs); + } + + const bufs = []; + if (fullyQualifiedFormat) { + bufs.push(Buffer.from([this.ioc.DataType.COMPOSITEPDT, 0x00])); + } + bufs.push(this.ioc.stringSerializer.serialize(item.name, true)); + bufs.push(this.ioc.mapSerializer.serialize(item.properties, true)); + return Buffer.concat(bufs); + } + + async deserializeValue(reader, valueFlag, typeCode) { + const name = await this.ioc.anySerializer.deserialize(reader); + if (!name) { + throw new Error('CompositePDTSerializer: name cannot be null or empty'); + } + const properties = await this.ioc.anySerializer.deserialize(reader); + const props = properties instanceof Map ? Object.fromEntries(properties) : properties || {}; + const pdt = new ProviderDefinedType(name, props); + const pdtRegistry = reader.pdtRegistry; + if (pdtRegistry) { + const hydrated = pdtRegistry.hydrate(pdt); + if (!(hydrated instanceof ProviderDefinedType)) { + return hydrated; + } + } + return pdt; + } + + async deserialize(reader) { + const type_code = await reader.readUInt8(); + if (type_code !== this.ioc.DataType.COMPOSITEPDT) { + throw new Error(`CompositePDTSerializer: unexpected {type_code}=0x${type_code.toString(16)}`); + } + const value_flag = await reader.readUInt8(); + if (value_flag === 0x01) { + return null; + } + if (value_flag !== 0x00) { + throw new Error(`CompositePDTSerializer: unexpected {value_flag}=0x${value_flag.toString(16)}`); + } + return this.deserializeValue(reader, value_flag, type_code); + } +} diff --git a/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/GraphBinaryReader.js b/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/GraphBinaryReader.js index a2e43175e26..a4bdddabf2f 100644 --- a/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/GraphBinaryReader.js +++ b/gremlin-js/gremlin-javascript/lib/structure/io/binary/internals/GraphBinaryReader.js @@ -40,6 +40,7 @@ const StatusCode = { export default class GraphBinaryReader { constructor(ioc) { this.ioc = ioc; + this.pdtRegistry = null; } get mimeType() { @@ -64,6 +65,7 @@ export default class GraphBinaryReader { } const reader = StreamReader.fromBuffer(buffer); + reader.pdtRegistry = this.pdtRegistry; return await this.#readFromReader(reader); } @@ -81,6 +83,7 @@ export default class GraphBinaryReader { * @returns {AsyncGenerator} */ async *readResponseStream(reader) { + reader.pdtRegistry = this.pdtRegistry; // {version} const version = await reader.readUInt8(); if (version !== 0x84) { diff --git a/gremlin-js/gremlin-javascript/test/integration/client-tests.js b/gremlin-js/gremlin-javascript/test/integration/client-tests.js index aca3548d400..477f4570b21 100644 --- a/gremlin-js/gremlin-javascript/test/integration/client-tests.js +++ b/gremlin-js/gremlin-javascript/test/integration/client-tests.js @@ -18,7 +18,7 @@ */ import assert from 'assert'; -import { Vertex, Edge, VertexProperty } from '../../lib/structure/graph.js'; +import { Vertex, Edge, VertexProperty, ProviderDefinedType } from '../../lib/structure/graph.js'; import { getClient, serverUrl } from '../helper.js'; import { cardinality } from '../../lib/process/traversal.js'; import Client from '../../lib/driver/client.js'; @@ -208,4 +208,69 @@ function assertVertexProperties(vertex) { assert.ok(end); assert.strictEqual(start.value, 1990); assert.strictEqual(end.value, 2000); -} \ No newline at end of file +} + +describe('ProviderDefinedType - Client', function () { + let pdtClient; + before(function () { + pdtClient = getClient('gmodern'); + return pdtClient.open(); + }); + after(function () { + return pdtClient.close(); + }); + + it('should round-trip a simple Point PDT', function () { + return pdtClient.submit('g.inject(PDT("Point", ["x":1, "y":2]))') + .then(function (result) { + assert.strictEqual(result.length, 1); + const pdt = result.first(); + assert.ok(pdt instanceof ProviderDefinedType); + assert.strictEqual(pdt.name, 'Point'); + assert.strictEqual(pdt.properties.x, 1); + assert.strictEqual(pdt.properties.y, 2); + }); + }); + + it('should round-trip a nested PDT (Person with Address)', function () { + return pdtClient.submit( + 'g.inject(PDT("Person", ["name":"Alice", "age":30, ' + + '"address":PDT("Address", ["street":"123 Main St", "city":"Springfield", "zip":"12345"])]))') + .then(function (result) { + assert.strictEqual(result.length, 1); + const pdt = result.first(); + assert.ok(pdt instanceof ProviderDefinedType); + assert.strictEqual(pdt.name, 'Person'); + assert.strictEqual(pdt.properties.name, 'Alice'); + assert.strictEqual(pdt.properties.age, 30); + + const address = pdt.properties.address; + assert.ok(address instanceof ProviderDefinedType); + assert.strictEqual(address.name, 'Address'); + assert.strictEqual(address.properties.street, '123 Main St'); + assert.strictEqual(address.properties.city, 'Springfield'); + assert.strictEqual(address.properties.zip, '12345'); + }); + }); + + it('should handle PDTs in a collection', function () { + return pdtClient.submit( + 'g.inject([PDT("Point", ["x":1, "y":2]), PDT("Point", ["x":3, "y":4])])') + .then(function (result) { + assert.strictEqual(result.length, 1); + const list = result.first(); + assert.ok(Array.isArray(list)); + assert.strictEqual(list.length, 2); + + assert.ok(list[0] instanceof ProviderDefinedType); + assert.strictEqual(list[0].name, 'Point'); + assert.strictEqual(list[0].properties.x, 1); + assert.strictEqual(list[0].properties.y, 2); + + assert.ok(list[1] instanceof ProviderDefinedType); + assert.strictEqual(list[1].name, 'Point'); + assert.strictEqual(list[1].properties.x, 3); + assert.strictEqual(list[1].properties.y, 4); + }); + }); +}); \ No newline at end of file diff --git a/gremlin-js/gremlin-javascript/test/integration/traversal-test.js b/gremlin-js/gremlin-javascript/test/integration/traversal-test.js index 1ada7cabffc..356b01f6d94 100644 --- a/gremlin-js/gremlin-javascript/test/integration/traversal-test.js +++ b/gremlin-js/gremlin-javascript/test/integration/traversal-test.js @@ -23,7 +23,8 @@ import assert from 'assert'; import { AssertionError } from 'assert'; -import {Edge, Vertex, VertexProperty} from '../../lib/structure/graph.js'; +import {Edge, Vertex, VertexProperty, ProviderDefinedType} from '../../lib/structure/graph.js'; +import { ProviderDefinedTypeRegistry } from '../../lib/structure/ProviderDefinedTypeRegistry.js'; import anon from '../../lib/process/anonymous-traversal.js'; import { GraphTraversalSource, GraphTraversal, statics } from '../../lib/process/graph-traversal.js'; import { @@ -31,6 +32,7 @@ import { OptionsStrategy, ReservedKeysVerificationStrategy, EdgeLabelVerificationStrategy, MatchAlgorithmStrategy } from '../../lib/process/traversal-strategy.js'; import GremlinLang from '../../lib/process/gremlin-lang.js'; +import DriverRemoteConnection from '../../lib/driver/driver-remote-connection.js'; import { getConnection, getDriverRemoteConnection } from '../helper.js'; const __ = statics; @@ -433,4 +435,79 @@ describe('Traversal', function () { // assert.ok(!tx._sessionBasedConnection.isOpen); // }); // }); +}); + +let serverUrl; +if (process.env.DOCKER_ENVIRONMENT === 'true') { + serverUrl = 'http://gremlin-server-test-js:45940/gremlin'; +} else { + serverUrl = 'http://localhost:45940/gremlin'; +} + +describe('ProviderDefinedType - Traversal API', function () { + describe('raw PDT round-trip via Traversal API', function () { + let pdtConnection; + + before(function () { + pdtConnection = getConnection('gmodern'); + return pdtConnection.open(); + }); + after(function () { + return pdtConnection.close(); + }); + + it('should round-trip a PDT via g.inject()', async function () { + const g = anon.traversal().with_(pdtConnection); + const pdt = new ProviderDefinedType('TestPoint', { x: 1, y: 2 }); + + const results = await g.inject(pdt).toList(); + + assert.strictEqual(results.length, 1); + const result = results[0]; + assert.ok(result instanceof ProviderDefinedType); + assert.strictEqual(result.name, 'TestPoint'); + assert.strictEqual(result.properties.x, 1); + assert.strictEqual(result.properties.y, 2); + }); + }); + + describe('registry-based round-trip via typed object', function () { + let pdtConnection; + + class TestPoint { + constructor(x, y) { + this.x = x; + this.y = y; + } + } + + before(function () { + const registry = new ProviderDefinedTypeRegistry(); + registry.register('TestPoint', { + serialize: (obj) => ({ x: obj.x, y: obj.y }), + deserialize: (props) => new TestPoint(props.x, props.y), + }, TestPoint); + pdtConnection = new DriverRemoteConnection(serverUrl, { + traversalSource: 'gmodern', + pdtRegistry: registry, + }); + return pdtConnection.open(); + }); + after(function () { + return pdtConnection.close(); + }); + + it('should auto-dehydrate on send and auto-hydrate on receive via registry', async function () { + const g = anon.traversal().with_(pdtConnection); + const point = new TestPoint(5, 10); + + const results = await g.inject(point).toList(); + + assert.strictEqual(results.length, 1); + const result = results[0]; + assert.ok(result instanceof TestPoint); + assert.strictEqual(result.x, 5); + assert.strictEqual(result.y, 10); + }); + }); }); \ No newline at end of file diff --git a/gremlin-js/gremlin-javascript/test/unit/graphbinary/CompositePDTSerializer-test.js b/gremlin-js/gremlin-javascript/test/unit/graphbinary/CompositePDTSerializer-test.js new file mode 100644 index 00000000000..0ac6b282545 --- /dev/null +++ b/gremlin-js/gremlin-javascript/test/unit/graphbinary/CompositePDTSerializer-test.js @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { assert } from 'chai'; +import { ProviderDefinedType } from '../../../lib/structure/graph.js'; +import { ProviderDefinedTypeRegistry } from '../../../lib/structure/ProviderDefinedTypeRegistry.js'; +import ioc, { DataType } from '../../../lib/structure/io/binary/GraphBinary.js'; +import StreamReader from '../../../lib/structure/io/binary/internals/StreamReader.js'; + +const { anySerializer, compositePDTSerializer } = ioc; + +async function roundTrip(value) { + const bytes = anySerializer.serialize(value); + return anySerializer.deserialize(StreamReader.fromBuffer(bytes)); +} + +describe('CompositePDTSerializer', () => { + describe('round-trip: simple PDT', () => { + it('serializes and deserializes a simple ProviderDefinedType', async () => { + const pdt = new ProviderDefinedType('myType', { key1: 'value1', key2: 42 }); + const result = await roundTrip(pdt); + assert.instanceOf(result, ProviderDefinedType); + assert.strictEqual(result.name, 'myType'); + assert.strictEqual(result.properties.key1, 'value1'); + assert.strictEqual(result.properties.key2, 42); + }); + + it('uses COMPOSITEPDT type code', () => { + const pdt = new ProviderDefinedType('test', { a: 1 }); + const bytes = anySerializer.serialize(pdt); + assert.strictEqual(bytes[0], DataType.COMPOSITEPDT); + }); + }); + + describe('round-trip: nested PDT', () => { + it('serializes and deserializes a PDT with nested PDT in properties', async () => { + const inner = new ProviderDefinedType('inner', { x: 'hello' }); + const outer = new ProviderDefinedType('outer', { nested: inner, num: 99 }); + const result = await roundTrip(outer); + assert.instanceOf(result, ProviderDefinedType); + assert.strictEqual(result.name, 'outer'); + assert.strictEqual(result.properties.num, 99); + assert.instanceOf(result.properties.nested, ProviderDefinedType); + assert.strictEqual(result.properties.nested.name, 'inner'); + assert.strictEqual(result.properties.nested.properties.x, 'hello'); + }); + }); + + describe('round-trip: null/undefined field value', () => { + it('handles null property values', async () => { + const pdt = new ProviderDefinedType('withNull', { present: 'yes', absent: null }); + const result = await roundTrip(pdt); + assert.instanceOf(result, ProviderDefinedType); + assert.strictEqual(result.name, 'withNull'); + assert.strictEqual(result.properties.present, 'yes'); + assert.strictEqual(result.properties.absent, null); + }); + }); + + describe('empty name rejected', () => { + it('constructor rejects empty string name', () => { + assert.throws(() => new ProviderDefinedType('', { a: 1 }), /name cannot be null or empty/); + }); + + it('constructor rejects null name', () => { + assert.throws(() => new ProviderDefinedType(null, { a: 1 }), /name cannot be null or empty/); + }); + + it('constructor rejects undefined name', () => { + assert.throws(() => new ProviderDefinedType(undefined, { a: 1 }), /name cannot be null or empty/); + }); + + it('deserializer rejects null name from wire', async () => { + // Manually craft bytes: type_code=0xf0, value_flag=0x00, then null string, then empty map + const nullString = Buffer.from([DataType.STRING, 0x01]); // null string + const emptyMap = Buffer.from([DataType.MAP, 0x00, 0x00, 0x00, 0x00, 0x00]); // map with 0 entries + const bytes = Buffer.concat([ + Buffer.from([DataType.COMPOSITEPDT, 0x00]), + nullString, + emptyMap, + ]); + try { + await anySerializer.deserialize(StreamReader.fromBuffer(bytes)); + assert.fail('should have thrown'); + } catch (e) { + assert.match(e.message, /name cannot be null or empty/); + } + }); + }); + + describe('canBeUsedFor', () => { + it('returns true for ProviderDefinedType instances', () => { + assert.isTrue(compositePDTSerializer.canBeUsedFor(new ProviderDefinedType('t', {}))); + }); + + it('returns false for plain objects', () => { + assert.isFalse(compositePDTSerializer.canBeUsedFor({ name: 'test' })); + }); + + it('returns false for strings', () => { + assert.isFalse(compositePDTSerializer.canBeUsedFor('test')); + }); + }); + + describe('auto-hydration via pdtRegistry', () => { + it('auto-hydrates when pdtRegistry is set on the reader', async () => { + const registry = new ProviderDefinedTypeRegistry(); + registry.register('myType', { + serialize: (obj) => obj, + deserialize: (props) => ({ hydrated: true, ...props }), + }); + + const pdt = new ProviderDefinedType('myType', { key1: 'value1', key2: 42 }); + const bytes = anySerializer.serialize(pdt); + const reader = StreamReader.fromBuffer(bytes); + reader.pdtRegistry = registry; + const result = await anySerializer.deserialize(reader); + + assert.notInstanceOf(result, ProviderDefinedType); + assert.strictEqual(result.hydrated, true); + assert.strictEqual(result.key1, 'value1'); + assert.strictEqual(result.key2, 42); + }); + + it('returns raw PDT when no pdtRegistry is set', async () => { + const pdt = new ProviderDefinedType('myType', { key1: 'value1' }); + const bytes = anySerializer.serialize(pdt); + const result = await anySerializer.deserialize(StreamReader.fromBuffer(bytes)); + + assert.instanceOf(result, ProviderDefinedType); + assert.strictEqual(result.name, 'myType'); + }); + }); +}); diff --git a/gremlin-js/gremlin-javascript/test/unit/graphbinary/precise-mode-test.js b/gremlin-js/gremlin-javascript/test/unit/graphbinary/precise-mode-test.js index c2b9246713e..259642aa897 100644 --- a/gremlin-js/gremlin-javascript/test/unit/graphbinary/precise-mode-test.js +++ b/gremlin-js/gremlin-javascript/test/unit/graphbinary/precise-mode-test.js @@ -515,7 +515,7 @@ describe('Precise Mode Tests', () => { it('default uses the default reader', () => { const conn = new Connection('http://localhost:8182', {}); - assert.strictEqual(conn._reader, graphBinaryReader); + assert.ok(conn._reader instanceof graphBinaryReader.constructor); }); }); }); diff --git a/gremlin-js/gremlin-javascript/test/unit/gremlin-lang-test.js b/gremlin-js/gremlin-javascript/test/unit/gremlin-lang-test.js index a8c4e9add02..24cc0c4d648 100644 --- a/gremlin-js/gremlin-javascript/test/unit/gremlin-lang-test.js +++ b/gremlin-js/gremlin-javascript/test/unit/gremlin-lang-test.js @@ -25,7 +25,7 @@ import { P, TextP, t as T, order as Order, scope as Scope, column as Column, withOptions as WithOptions, direction } from '../../lib/process/traversal.js'; import { ReadOnlyStrategy, SubgraphStrategy, OptionsStrategy, PartitionStrategy, SeedStrategy } from '../../lib/process/traversal-strategy.js'; -import { Graph, Vertex } from '../../lib/structure/graph.js'; +import { Graph, Vertex, ProviderDefinedType } from '../../lib/structure/graph.js'; import { TraversalStrategies } from '../../lib/process/traversal-strategy.js'; import { Long, toFloat, toDouble, toShort, toByte, toInt, toLong } from '../../lib/utils.js'; import GremlinLang from '../../lib/process/gremlin-lang.js'; @@ -626,4 +626,39 @@ describe('GremlinLang', function () { assert.ok(result.includes("'name':'marko'")); }); }); -}); \ No newline at end of file + + describe('PDT gremlin-lang tests', function () { + it('should handle basic PDT', function () { + const pdt = new ProviderDefinedType('Point', { x: 1, y: 2 }); + assert.strictEqual( + g.inject(pdt).getGremlinLang().getGremlin(), + "g.inject(PDT(\"Point\",['x':1,'y':2]))" + ); + }); + + it('should handle PDT with special chars in name (quotes)', function () { + const pdt = new ProviderDefinedType('my"type', { a: 1 }); + assert.strictEqual( + g.inject(pdt).getGremlinLang().getGremlin(), + "g.inject(PDT(\"my\\\"type\",['a':1]))" + ); + }); + + it('should handle nested PDT', function () { + const inner = new ProviderDefinedType('Inner', { v: 42 }); + const outer = new ProviderDefinedType('Outer', { child: inner }); + assert.strictEqual( + g.inject(outer).getGremlinLang().getGremlin(), + "g.inject(PDT(\"Outer\",['child':PDT(\"Inner\",['v':42])]))" + ); + }); + + it('should handle PDT with empty properties', function () { + const pdt = new ProviderDefinedType('Empty', {}); + assert.strictEqual( + g.inject(pdt).getGremlinLang().getGremlin(), + "g.inject(PDT(\"Empty\",[:]))" + ); + }); + }); +}); diff --git a/gremlin-js/gremlin-javascript/test/unit/pdt-registry-test.js b/gremlin-js/gremlin-javascript/test/unit/pdt-registry-test.js new file mode 100644 index 00000000000..7f17544bdf8 --- /dev/null +++ b/gremlin-js/gremlin-javascript/test/unit/pdt-registry-test.js @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { assert } from 'chai'; +import { ProviderDefinedType } from '../../lib/structure/graph.js'; +import { ProviderDefinedTypeRegistry } from '../../lib/structure/ProviderDefinedTypeRegistry.js'; +import Client from '../../lib/driver/client.js'; +import Connection from '../../lib/driver/connection.js'; + +describe('ProviderDefinedTypeRegistry', () => { + describe('#hydrate()', () => { + it('should return a typed object when an adapter is registered', () => { + const registry = new ProviderDefinedTypeRegistry(); + registry.register('GeoPoint', { + serialize: (obj) => ({ lat: obj.lat, lon: obj.lon }), + deserialize: (props) => ({ type: 'GeoPoint', lat: props.lat, lon: props.lon }), + }); + + const pdt = new ProviderDefinedType('GeoPoint', { lat: 37.7749, lon: -122.4194 }); + const result = registry.hydrate(pdt); + + assert.deepStrictEqual(result, { type: 'GeoPoint', lat: 37.7749, lon: -122.4194 }); + }); + + it('should return the raw PDT when no adapter is registered', () => { + const registry = new ProviderDefinedTypeRegistry(); + const pdt = new ProviderDefinedType('Unknown', { foo: 'bar' }); + const result = registry.hydrate(pdt); + + assert.strictEqual(result, pdt); + assert.instanceOf(result, ProviderDefinedType); + }); + + it('should fall back gracefully when adapter throws', () => { + const registry = new ProviderDefinedTypeRegistry(); + registry.register('Broken', { + serialize: () => ({}), + deserialize: () => { throw new Error('adapter error'); }, + }); + + const pdt = new ProviderDefinedType('Broken', { x: 1 }); + const warnings = []; + const origWarn = console.warn; + console.warn = (msg) => warnings.push(msg); + try { + const result = registry.hydrate(pdt); + assert.strictEqual(result, pdt); + assert.lengthOf(warnings, 1); + assert.include(warnings[0], 'adapter error'); + assert.include(warnings[0], 'Broken'); + } finally { + console.warn = origWarn; + } + }); + + it('should recursively hydrate nested PDTs', () => { + const registry = new ProviderDefinedTypeRegistry(); + registry.register('Address', { + serialize: (obj) => obj, + deserialize: (props) => ({ type: 'Address', city: props.city, zip: props.zip }), + }); + registry.register('Person', { + serialize: (obj) => obj, + deserialize: (props) => ({ type: 'Person', name: props.name, address: props.address }), + }); + + const addressPdt = new ProviderDefinedType('Address', { city: 'Portland', zip: '97201' }); + const personPdt = new ProviderDefinedType('Person', { name: 'Alice', address: addressPdt }); + + const result = registry.hydrate(personPdt); + + assert.deepStrictEqual(result, { + type: 'Person', + name: 'Alice', + address: { type: 'Address', city: 'Portland', zip: '97201' }, + }); + }); + + it('should return non-PDT values unchanged', () => { + const registry = new ProviderDefinedTypeRegistry(); + assert.strictEqual(registry.hydrate('hello'), 'hello'); + assert.strictEqual(registry.hydrate(42), 42); + assert.strictEqual(registry.hydrate(null), null); + }); + }); + + describe('#hasAdapter()', () => { + it('should return true for registered types', () => { + const registry = new ProviderDefinedTypeRegistry(); + registry.register('Foo', { serialize: () => ({}), deserialize: (p) => p }); + assert.isTrue(registry.hasAdapter('Foo')); + assert.isFalse(registry.hasAdapter('Bar')); + }); + }); + + describe('#getSerializer()', () => { + it('should return the serialize function for registered types', () => { + const registry = new ProviderDefinedTypeRegistry(); + const serFn = (obj) => ({ val: obj.val }); + registry.register('Custom', { serialize: serFn, deserialize: (p) => p }); + assert.strictEqual(registry.getSerializer('Custom'), serFn); + }); + + it('should return null for unregistered types', () => { + const registry = new ProviderDefinedTypeRegistry(); + assert.isNull(registry.getSerializer('Missing')); + }); + }); +}); + +describe('pdtRegistry wiring through Client/Connection', () => { + it('should set pdtRegistry on the reader when passed via Connection options', () => { + const registry = new ProviderDefinedTypeRegistry(); + registry.register('GeoPoint', { + serialize: (obj) => obj, + deserialize: (props) => ({ type: 'GeoPoint', ...props }), + }); + + const conn = new Connection('http://localhost:8182', { pdtRegistry: registry }); + assert.strictEqual(conn._reader.pdtRegistry, registry); + }); + + it('should set pdtRegistry on the reader when passed via Client options', () => { + const registry = new ProviderDefinedTypeRegistry(); + registry.register('GeoPoint', { + serialize: (obj) => obj, + deserialize: (props) => ({ type: 'GeoPoint', ...props }), + }); + + const client = new Client('http://localhost:8182', { pdtRegistry: registry }); + assert.strictEqual(client._connection._reader.pdtRegistry, registry); + }); + + it('should not leak pdtRegistry between connections', () => { + const registry = new ProviderDefinedTypeRegistry(); + const conn1 = new Connection('http://localhost:8182', { pdtRegistry: registry }); + const conn2 = new Connection('http://localhost:8182'); + assert.isNull(conn2._reader.pdtRegistry); + assert.strictEqual(conn1._reader.pdtRegistry, registry); + }); +}); diff --git a/gremlin-language/src/main/antlr4/Gremlin.g4 b/gremlin-language/src/main/antlr4/Gremlin.g4 index 6798ebde55b..fbe8d414048 100644 --- a/gremlin-language/src/main/antlr4/Gremlin.g4 +++ b/gremlin-language/src/main/antlr4/Gremlin.g4 @@ -1636,6 +1636,7 @@ genericLiteral | characterLiteral | durationLiteral | binaryLiteral + | pdtLiteral | genericMapLiteral ; @@ -1728,6 +1729,10 @@ binaryLiteral : K_BINARYC LPAREN stringLiteral RPAREN ; +pdtLiteral + : K_PDT LPAREN stringLiteral COMMA genericMapLiteral RPAREN + ; + nakedKey : Identifier @@ -1951,6 +1956,7 @@ keyword | K_PAGERANKU | K_PATH | K_PATHU + | K_PDT | K_PEERPRESSURE | K_PEERPRESSUREU | K_PICK @@ -2263,6 +2269,7 @@ K_PAGERANKU: 'PageRank'; K_PAGERANK: 'pageRank'; K_PATH: 'path'; K_PATHU: 'PATH'; +K_PDT: 'PDT'; K_PEERPRESSUREU: 'PeerPressure'; K_PEERPRESSURE: 'peerPressure'; K_PICK: 'Pick'; diff --git a/gremlin-python/src/main/python/gremlin_python/driver/client.py b/gremlin-python/src/main/python/gremlin_python/driver/client.py index 8b5ff1466c5..5c1544db94d 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/client.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/client.py @@ -43,7 +43,7 @@ def __init__(self, url, traversal_source, pool_size=None, max_workers=None, request_serializer=serializer.GraphBinarySerializersV4(), response_serializer=None, interceptors=None, auth=None, headers=None, enable_user_agent_on_connect=True, - bulk_results=False, **transport_kwargs): + bulk_results=False, pdt_registry=None, **transport_kwargs): log.info("Creating Client with url '%s'", url) self._closed = False @@ -54,6 +54,10 @@ def __init__(self, url, traversal_source, pool_size=None, max_workers=None, self._traversal_source = traversal_source if response_serializer is None: response_serializer = serializer.GraphBinarySerializersV4() + if pdt_registry is not None: + if request_serializer is not None: + request_serializer.configure_pdt_registry(pdt_registry) + response_serializer.configure_pdt_registry(pdt_registry) self._auth = auth self._response_serializer = response_serializer diff --git a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py index fdd25280cf9..a150f3a2332 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py @@ -36,7 +36,7 @@ def __init__(self, url, traversal_source="g", request_serializer=serializer.GraphBinarySerializersV4(), response_serializer=None, interceptors=None, auth=None, headers=None, enable_user_agent_on_connect=True, - bulk_results=False, **transport_kwargs): + bulk_results=False, pdt_registry=None, **transport_kwargs): log.info("Creating DriverRemoteConnection with url '%s'", str(url)) self.__url = url self.__traversal_source = traversal_source @@ -47,6 +47,7 @@ def __init__(self, url, traversal_source="g", self.__enable_user_agent_on_connect = enable_user_agent_on_connect self.__bulk_results = bulk_results self.__transport_kwargs = transport_kwargs + self.pdt_registry = pdt_registry if response_serializer is None: response_serializer = serializer.GraphBinarySerializersV4() @@ -59,6 +60,7 @@ def __init__(self, url, traversal_source="g", headers=headers, enable_user_agent_on_connect=enable_user_agent_on_connect, bulk_results=bulk_results, + pdt_registry=pdt_registry, **transport_kwargs) self._url = self._client._url self._traversal_source = self._client._traversal_source diff --git a/gremlin-python/src/main/python/gremlin_python/driver/serializer.py b/gremlin-python/src/main/python/gremlin_python/driver/serializer.py index 5c51f2507b4..c333a4cc540 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/serializer.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/serializer.py @@ -34,17 +34,24 @@ class GraphBinarySerializersV4(object): int_pack = graphbinaryV4.int32_pack - def __init__(self, reader=None, writer=None, version=None): + def __init__(self, reader=None, writer=None, version=None, pdt_registry=None): if not version: version = self.DEFAULT_VERSION self._version = version if not reader: - reader = self.DEFAULT_READER_CLASS() + reader = self.DEFAULT_READER_CLASS(pdt_registry=pdt_registry) self._graphbinary_reader = reader if not writer: writer = self.DEFAULT_WRITER_CLASS() self._graphbinary_writer = writer + def configure_pdt_registry(self, pdt_registry): + if self._graphbinary_reader.pdt_registry is None: + self._graphbinary_reader.pdt_registry = pdt_registry + else: + self._graphbinary_reader.pdt_registry._adapters_by_name.update(pdt_registry._adapters_by_name) + self._graphbinary_reader.pdt_registry._adapters_by_class.update(pdt_registry._adapters_by_class) + @property def version(self): """Read only property""" diff --git a/gremlin-python/src/main/python/gremlin_python/process/graph_traversal.py b/gremlin-python/src/main/python/gremlin_python/process/graph_traversal.py index 082df46d1ae..9f7d1ea29e5 100644 --- a/gremlin-python/src/main/python/gremlin_python/process/graph_traversal.py +++ b/gremlin-python/src/main/python/gremlin_python/process/graph_traversal.py @@ -46,6 +46,8 @@ def __init__(self, graph, traversal_strategies, gremlin_lang=None, remote_connec self.graph_traversal = GraphTraversal if remote_connection: self.traversal_strategies.add_strategies([RemoteStrategy(remote_connection)]) + if hasattr(remote_connection, 'pdt_registry') and remote_connection.pdt_registry is not None: + self.gremlin_lang.pdt_registry = remote_connection.pdt_registry self.remote_connection = remote_connection def __repr__(self): diff --git a/gremlin-python/src/main/python/gremlin_python/process/traversal.py b/gremlin-python/src/main/python/gremlin_python/process/traversal.py index 76c91822e75..93f835c57f9 100644 --- a/gremlin-python/src/main/python/gremlin_python/process/traversal.py +++ b/gremlin-python/src/main/python/gremlin_python/process/traversal.py @@ -24,7 +24,7 @@ import warnings from aenum import Enum -from gremlin_python.structure.graph import Vertex, Edge, Path, Property +from gremlin_python.structure.graph import Vertex, Edge, Path, Property, ProviderDefinedType from .. import statics from ..statics import long, SingleByte, SingleChar, short, bigint, BigDecimal @@ -814,11 +814,13 @@ def __init__(self, gremlin_lang=None): self.gremlin = [] self.parameters = {} self.options_strategies = [] + self.pdt_registry = None if gremlin_lang is not None: self.gremlin = list(gremlin_lang.gremlin) self.parameters = dict(gremlin_lang.parameters) self.options_strategies = list(gremlin_lang.options_strategies) + self.pdt_registry = gremlin_lang.pdt_registry def _add_to_gremlin(self, string_name, *args): @@ -904,6 +906,9 @@ def _arg_as_string(self, arg): else: return tmp + if isinstance(arg, ProviderDefinedType): + return f'PDT({self._arg_as_string(arg.name)},{self._process_dict(arg.properties)})' + if isinstance(arg, Vertex): return f'{self._arg_as_string(arg.id)}' @@ -947,6 +952,25 @@ def _arg_as_string(self, arg): if isinstance(arg, type): return arg.__name__ + # Registry-based dehydration + if self.pdt_registry is not None: + adapter = self.pdt_registry.get_adapter_by_class(type(arg)) + if adapter is not None and adapter['serialize'] is not None: + props = adapter['serialize'](arg) + return self._arg_as_string(ProviderDefinedType(adapter['type_name'], props)) + + # Auto-dehydrate @provider_defined decorated objects + if hasattr(arg, '_pdt_name'): + included = getattr(arg, '_pdt_included_fields', None) + excluded = getattr(arg, '_pdt_excluded_fields', None) + fields = [f for f in vars(arg) if not f.startswith('_')] + if included: + fields = [f for f in fields if f in included] + elif excluded: + fields = [f for f in fields if f not in excluded] + pdt = ProviderDefinedType(arg._pdt_name, {f: getattr(arg, f) for f in fields}) + return self._arg_as_string(pdt) + raise TypeError(f'GremlinLang contains at least one type [{type(arg).__name__}] that cannot be represented as text.') # Do special processing needed to format predicates that come in diff --git a/gremlin-python/src/main/python/gremlin_python/structure/graph.py b/gremlin-python/src/main/python/gremlin_python/structure/graph.py index 2ff6e56174a..61cec674d83 100644 --- a/gremlin-python/src/main/python/gremlin_python/structure/graph.py +++ b/gremlin-python/src/main/python/gremlin_python/structure/graph.py @@ -141,3 +141,114 @@ def __getitem__(self, key): def __len__(self): return len(self.objects) + + +class ProviderDefinedType(object): + def __init__(self, name, properties): + if not name: + raise ValueError("name cannot be null or empty") + self._name = name + self._properties = dict(properties) if properties else {} + + @property + def name(self): + return self._name + + @property + def properties(self): + return self._properties + + def __eq__(self, other): + return isinstance(other, ProviderDefinedType) and self._name == other._name and self._properties == other._properties + + def __hash__(self): + try: + return hash((self._name, frozenset(self._properties.items()))) + except TypeError: + return hash(self._name) + + def __repr__(self): + return f"pdt[{self._name}]{self._properties}" + + +class ProviderDefinedTypeRegistry(object): + def __init__(self): + self._adapters_by_name = {} + self._adapters_by_class = {} + + def register(self, type_name, deserialize_fn, serialize_fn=None, target_class=None): + self._adapters_by_name[type_name] = { + 'deserialize': deserialize_fn, + 'serialize': serialize_fn, + 'target_class': target_class + } + if target_class is not None: + self._adapters_by_class[target_class] = { + 'type_name': type_name, + 'serialize': serialize_fn, + } + + @classmethod + def build(cls): + """Create a registry populated by entry_points discovery. + + Providers register adapters via pyproject.toml: + [project.entry-points."tinkerpop.pdt"] + my_types = "my_package:register_pdt_types" + + Each entry point should be a callable that accepts a registry and registers adapters. + """ + import sys + registry = cls() + if sys.version_info >= (3, 10): + from importlib.metadata import entry_points + eps = entry_points(group='tinkerpop.pdt') + else: + from importlib.metadata import entry_points + all_eps = entry_points() + eps = all_eps.get('tinkerpop.pdt', []) + + for ep in eps: + try: + factory = ep.load() + factory(registry) + except Exception as e: + import logging + logging.getLogger(__name__).warning( + f"Failed to load PDT adapter from entry point '{ep.name}': {e}") + return registry + + def hydrate(self, pdt): + """Attempt to hydrate a ProviderDefinedType. Returns typed object or raw PDT.""" + if not isinstance(pdt, ProviderDefinedType): + return pdt + adapter = self._adapters_by_name.get(pdt.name) + if adapter is None: + return pdt + try: + hydrated_props = {k: self.hydrate(v) if isinstance(v, ProviderDefinedType) else v + for k, v in pdt.properties.items()} + return adapter['deserialize'](hydrated_props) + except Exception as e: + import logging + logging.getLogger(__name__).warning(f"PDT hydration failed for '{pdt.name}': {e}") + return pdt + + def get_adapter_by_class(self, cls): + """Return (type_name, serialize_fn) tuple for the given class, or None.""" + return self._adapters_by_class.get(cls) + + +# Module-level registry of @provider_defined decorated classes keyed by PDT name. +_pdt_decorated_types = {} + + +def provider_defined(name=None, included_fields=None, excluded_fields=None): + """Decorator that marks a class as a Provider Defined Type.""" + def decorator(cls): + cls._pdt_name = name or cls.__name__ + cls._pdt_included_fields = included_fields + cls._pdt_excluded_fields = excluded_fields + _pdt_decorated_types[cls._pdt_name] = cls + return cls + return decorator diff --git a/gremlin-python/src/main/python/gremlin_python/structure/io/graphbinaryV4.py b/gremlin-python/src/main/python/gremlin_python/structure/io/graphbinaryV4.py index 4cdc3a14eec..cc5d9518fd3 100644 --- a/gremlin-python/src/main/python/gremlin_python/structure/io/graphbinaryV4.py +++ b/gremlin-python/src/main/python/gremlin_python/structure/io/graphbinaryV4.py @@ -30,7 +30,8 @@ from gremlin_python.process.traversal import Direction, T, Merge, GType from gremlin_python.statics import FloatType, BigDecimal, ShortType, IntType, LongType, BigIntType, \ DictType, SetType, SingleByte, SingleChar -from gremlin_python.structure.graph import Graph, Edge, Property, Vertex, VertexProperty, Path +from gremlin_python.structure.graph import Graph, Edge, Property, Vertex, VertexProperty, Path, ProviderDefinedType, \ + _pdt_decorated_types from gremlin_python.structure.io.util import HashableDict, SymbolUtil, Marker log = logging.getLogger(__name__) @@ -73,8 +74,8 @@ class DataType(Enum): gtype = 0x30 char = 0x80 duration = 0x81 + composite_pdt = 0xf0 marker = 0xfd - custom = 0x00 # todo NULL_BYTES = [DataType.null.value, 0x01] @@ -144,10 +145,11 @@ def to_dict(self, obj, to_extend=None): class GraphBinaryReader(object): - def __init__(self, deserializer_map=None): + def __init__(self, deserializer_map=None, pdt_registry=None): self.deserializers = _deserializers.copy() if deserializer_map: self.deserializers.update(deserializer_map) + self.pdt_registry = pdt_registry def read_object(self, b): if b is None: @@ -163,9 +165,33 @@ def to_object(self, buff, data_type=None, nullable=True): if nullable: buff.read(1) return None - return self.deserializers[DataType(bt)].objectify(buff, self, nullable) + result = self.deserializers[DataType(bt)].objectify(buff, self, nullable) else: - return self.deserializers[data_type].objectify(buff, self, nullable) + result = self.deserializers[data_type].objectify(buff, self, nullable) + if self.pdt_registry is not None and isinstance(result, ProviderDefinedType): + hydrated = self.pdt_registry.hydrate(result) + if not isinstance(hydrated, ProviderDefinedType): + return hydrated + result = hydrated + if isinstance(result, ProviderDefinedType) and result.name in _pdt_decorated_types: + return self._hydrate_decorated(result) + return result + + def _hydrate_decorated(self, pdt): + """Hydrate a ProviderDefinedType using a @provider_defined decorated class.""" + cls = _pdt_decorated_types[pdt.name] + props = {} + for k, v in pdt.properties.items(): + if isinstance(v, ProviderDefinedType) and v.name in _pdt_decorated_types: + props[k] = self._hydrate_decorated(v) + elif self.pdt_registry is not None and isinstance(v, ProviderDefinedType): + props[k] = self.pdt_registry.hydrate(v) + else: + props[k] = v + obj = cls.__new__(cls) + for k, v in props.items(): + setattr(obj, k, v) + return obj class _GraphBinaryTypeIO(object, metaclass=GraphBinaryTypeType): @@ -921,4 +947,26 @@ def dictify(cls, obj, writer, to_extend, as_value=False, nullable=True): def objectify(cls, buff, reader, nullable=True): return cls.is_null(buff, reader, lambda b, r: Marker.of(int8_unpack(b.read(1))), - nullable) \ No newline at end of file + nullable) + + +class ProviderDefinedTypeIO(_GraphBinaryTypeIO): + python_type = ProviderDefinedType + graphbinary_type = DataType.composite_pdt + + @classmethod + def dictify(cls, obj, writer, to_extend, as_value=False, nullable=True): + cls.prefix_bytes(cls.graphbinary_type, as_value, nullable, to_extend) + StringIO.dictify(obj.name, writer, to_extend) + MapIO.dictify(obj.properties, writer, to_extend) + return to_extend + + @classmethod + def objectify(cls, buff, reader, nullable=True): + return cls.is_null(buff, reader, cls._read_pdt, nullable) + + @classmethod + def _read_pdt(cls, b, r): + name = r.read_object(b) + properties = r.read_object(b) + return ProviderDefinedType(name, properties) \ No newline at end of file diff --git a/gremlin-python/src/main/python/tests/integration/conftest.py b/gremlin-python/src/main/python/tests/integration/conftest.py index 76215bf24ce..b3fecc10ef4 100644 --- a/gremlin-python/src/main/python/tests/integration/conftest.py +++ b/gremlin-python/src/main/python/tests/integration/conftest.py @@ -18,6 +18,7 @@ # import concurrent.futures +from collections import namedtuple from json import dumps import os import ssl @@ -40,6 +41,14 @@ verbose_logging = False +# Shared namedtuple used by remote_connection_with_registry fixture and its tests. +RegistryPoint = namedtuple('RegistryPoint', ['x', 'y']) + + +@pytest.fixture +def registry_point_class(): + return RegistryPoint + logging.basicConfig(format='%(asctime)s [%(levelname)8s] [%(filename)15s:%(lineno)d - %(funcName)10s()] - %(message)s', level=logging.DEBUG if verbose_logging else logging.INFO) @@ -211,6 +220,27 @@ def fin(): return client +@pytest.fixture +def remote_connection_with_registry(request): + from gremlin_python.structure.graph import ProviderDefinedTypeRegistry + + registry = ProviderDefinedTypeRegistry() + registry.register('RegistryPoint', + deserialize_fn=lambda props: RegistryPoint(x=props['x'], y=props['y']), + serialize_fn=lambda p: {'x': p.x, 'y': p.y}, + target_class=RegistryPoint) + try: + remote_conn = DriverRemoteConnection(anonymous_url, 'gmodern', pdt_registry=registry) + except OSError: + pytest.skip('Gremlin Server is not running') + else: + def fin(): + remote_conn.close() + + request.addfinalizer(fin) + return remote_conn + + def json_interceptor(request): request['headers']['content-type'] = "application/json" request['payload'] = dumps({"gremlin": "g.inject(2)", "g": "g"}) diff --git a/gremlin-python/src/main/python/tests/integration/driver/test_client.py b/gremlin-python/src/main/python/tests/integration/driver/test_client.py index 00e427cd589..1fe09deaf86 100644 --- a/gremlin-python/src/main/python/tests/integration/driver/test_client.py +++ b/gremlin-python/src/main/python/tests/integration/driver/test_client.py @@ -26,6 +26,7 @@ from gremlin_python.driver.connection import GremlinServerError from gremlin_python.driver.request import RequestMessage from gremlin_python.driver.serializer import GraphBinarySerializersV4 +from gremlin_python.structure.graph import ProviderDefinedType from gremlin_python.process.graph_traversal import __, GraphTraversalSource from gremlin_python.process.traversal import TraversalStrategies, GValue from gremlin_python.process.strategies import OptionsStrategy @@ -554,3 +555,60 @@ def test_response_serializer_never_None(): def test_serializer_and_interceptor_forwarded(client_with_interceptor): result = client_with_interceptor.submit("g.inject(1)").next() assert 2 == result # interceptor changes request to g.inject(2) + +def test_simple_pdt_round_trip(client): + """Inject and retrieve a simple Point PDT.""" + results = client.submit( + "g.inject(PDT(\"Point\", [\"x\":1, \"y\":2]))" + ).all().result() + + assert len(results) == 1 + pdt = results[0] + assert isinstance(pdt, ProviderDefinedType) + assert pdt.name == 'Point' + assert pdt.properties['x'] == 1 + assert pdt.properties['y'] == 2 + + +def test_nested_pdt(client): + """Inject and retrieve a nested PDT (Person containing Address).""" + results = client.submit( + "g.inject(PDT(\"Person\", [\"name\":\"Alice\", \"age\":30, " + "\"address\":PDT(\"Address\", [\"street\":\"123 Main St\", \"city\":\"Springfield\", \"zip\":\"12345\"])]))" + ).all().result() + + assert len(results) == 1 + pdt = results[0] + assert isinstance(pdt, ProviderDefinedType) + assert pdt.name == 'Person' + assert pdt.properties['name'] == 'Alice' + assert pdt.properties['age'] == 30 + + address = pdt.properties['address'] + assert isinstance(address, ProviderDefinedType) + assert address.name == 'Address' + assert address.properties['street'] == '123 Main St' + assert address.properties['city'] == 'Springfield' + assert address.properties['zip'] == '12345' + + +def test_pdt_in_collection(client): + """Retrieve multiple PDTs as a list.""" + results = client.submit( + "g.inject([PDT(\"Point\", [\"x\":1, \"y\":2]), PDT(\"Point\", [\"x\":3, \"y\":4])])" + ).all().result() + + assert len(results) == 1 + pdt_list = results[0] + assert isinstance(pdt_list, list) + assert len(pdt_list) == 2 + + assert isinstance(pdt_list[0], ProviderDefinedType) + assert pdt_list[0].name == 'Point' + assert pdt_list[0].properties['x'] == 1 + assert pdt_list[0].properties['y'] == 2 + + assert isinstance(pdt_list[1], ProviderDefinedType) + assert pdt_list[1].name == 'Point' + assert pdt_list[1].properties['x'] == 3 + assert pdt_list[1].properties['y'] == 4 diff --git a/gremlin-python/src/main/python/tests/integration/driver/test_driver_remote_connection.py b/gremlin-python/src/main/python/tests/integration/driver/test_driver_remote_connection.py index 604190b0c65..91fc23a4911 100644 --- a/gremlin-python/src/main/python/tests/integration/driver/test_driver_remote_connection.py +++ b/gremlin-python/src/main/python/tests/integration/driver/test_driver_remote_connection.py @@ -26,7 +26,7 @@ from gremlin_python.process.traversal import TraversalStrategy, P, Order, T, DT, GValue, Cardinality, Scope from gremlin_python.process.graph_traversal import __ from gremlin_python.process.anonymous_traversal import traversal -from gremlin_python.structure.graph import Vertex, Edge, Graph +from gremlin_python.structure.graph import Vertex, Edge, Graph, ProviderDefinedType, provider_defined from gremlin_python.process.strategies import SubgraphStrategy, SeedStrategy, ReservedKeysVerificationStrategy from gremlin_python.structure.io.util import HashableDict from gremlin_python.driver.connection import GremlinServerError @@ -289,3 +289,34 @@ def test_forwards_interceptor_serializers(self, remote_connection_with_intercept g = traversal().with_(remote_connection_with_interceptor) result = g.inject(1).next() assert 2 == result # interceptor changes request to g.inject(2) + + def test_pdt_round_trip_via_traversal(self, remote_connection): + g = traversal().with_(remote_connection) + pdt = ProviderDefinedType('Point', {'x': 1, 'y': 2}) + result = g.inject(pdt).next() + assert isinstance(result, ProviderDefinedType) + assert result.name == 'Point' + assert result.properties == {'x': 1, 'y': 2} + + def test_pdt_registry_round_trip_via_traversal(self, remote_connection_with_registry, registry_point_class): + g = traversal().with_(remote_connection_with_registry) + point = registry_point_class(x=10, y=20) + result = g.inject(point).next() + # Registry auto-dehydrates on send and auto-hydrates on receive + assert isinstance(result, registry_point_class) + assert result.x == 10 + assert result.y == 20 + + def test_pdt_annotation_auto_dehydrate_via_traversal(self, remote_connection): + @provider_defined(name='TestPoint') + class TestPoint: + def __init__(self, x, y): + self.x = x + self.y = y + + g = traversal().with_(remote_connection) + point = TestPoint(5, 10) + result = g.inject(point).next() + assert isinstance(result, TestPoint) + assert result.x == 5 + assert result.y == 10 diff --git a/gremlin-python/src/main/python/tests/unit/process/test_gremlin_lang.py b/gremlin-python/src/main/python/tests/unit/process/test_gremlin_lang.py index edf0515317b..74a386163d3 100644 --- a/gremlin-python/src/main/python/tests/unit/process/test_gremlin_lang.py +++ b/gremlin-python/src/main/python/tests/unit/process/test_gremlin_lang.py @@ -579,3 +579,38 @@ def test_convert_parameters_to_string_escaped_string(self): result = GremlinLang.convert_parameters_to_string({'name': "it's a test"}) assert "'name'" in result assert "it" in result + + def test_provider_defined_auto_dehydration(self): + from gremlin_python.structure.graph import ProviderDefinedType, provider_defined + g = traversal().with_(None) + + @provider_defined(name="com.example.Point") + class Point: + def __init__(self, x, y): + self.x = x + self.y = y + + p = Point(1, 2) + gremlin = g.inject(p).gremlin_lang.get_gremlin() + assert "PDT('com.example.Point',['x':1,'y':2])" in gremlin + + def test_pdt_special_characters_in_name(self): + from gremlin_python.structure.graph import ProviderDefinedType + g = traversal().with_(None) + + pdt = ProviderDefinedType('say"hello"', {'v': 1}) + gremlin = g.inject(pdt).gremlin_lang.get_gremlin() + assert "PDT('say\"hello\"',['v':1])" in gremlin + + pdt2 = ProviderDefinedType('back\\slash', {'v': 1}) + gremlin2 = g.inject(pdt2).gremlin_lang.get_gremlin() + assert "PDT('back\\\\slash',['v':1])" in gremlin2 + + def test_pdt_nested(self): + from gremlin_python.structure.graph import ProviderDefinedType + g = traversal().with_(None) + + inner = ProviderDefinedType('Inner', {'v': 1}) + outer = ProviderDefinedType('Outer', {'inner': inner}) + gremlin = g.inject(outer).gremlin_lang.get_gremlin() + assert "PDT('Outer',['inner':PDT('Inner',['v':1])])" in gremlin diff --git a/gremlin-python/src/main/python/tests/unit/structure/io/test_graphbinaryV4.py b/gremlin-python/src/main/python/tests/unit/structure/io/test_graphbinaryV4.py index c1f2ac540ee..21b9cf7173c 100644 --- a/gremlin-python/src/main/python/tests/unit/structure/io/test_graphbinaryV4.py +++ b/gremlin-python/src/main/python/tests/unit/structure/io/test_graphbinaryV4.py @@ -23,7 +23,7 @@ from datetime import datetime, timedelta, timezone from gremlin_python.statics import long, bigint, BigDecimal, SingleByte, SingleChar -from gremlin_python.structure.graph import Graph, Vertex, Edge, Property, VertexProperty, Path +from gremlin_python.structure.graph import Graph, Vertex, Edge, Property, VertexProperty, Path, ProviderDefinedType from gremlin_python.structure.io.graphbinaryV4 import GraphBinaryWriter, GraphBinaryReader from gremlin_python.process.traversal import Direction from gremlin_python.structure.io.util import Marker @@ -316,3 +316,25 @@ def test_graph(self): assert len(re1.properties) == 1 assert re1.properties[0].key == "weight" assert re1.properties[0].value == 0.5 + + def test_provider_defined_type(self): + pdt = ProviderDefinedType('Point', {'x': 1, 'y': 2}) + result = self.graphbinary_reader.read_object(self.graphbinary_writer.write_object(pdt)) + assert isinstance(result, ProviderDefinedType) + assert result.name == 'Point' + assert result.properties == {'x': 1, 'y': 2} + + def test_provider_defined_type_nested(self): + inner = ProviderDefinedType('Address', {'street': 'Main'}) + outer = ProviderDefinedType('Person', {'name': 'Alice', 'address': inner}) + result = self.graphbinary_reader.read_object(self.graphbinary_writer.write_object(outer)) + assert result.name == 'Person' + assert result.properties['name'] == 'Alice' + assert isinstance(result.properties['address'], ProviderDefinedType) + assert result.properties['address'].name == 'Address' + + def test_provider_defined_type_null_field(self): + pdt = ProviderDefinedType('NullableType', {'value': None, 'name': 'test'}) + result = self.graphbinary_reader.read_object(self.graphbinary_writer.write_object(pdt)) + assert result.properties['value'] is None + assert result.properties['name'] == 'test' diff --git a/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py b/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py new file mode 100644 index 00000000000..2d5a47cb08b --- /dev/null +++ b/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py @@ -0,0 +1,220 @@ +""" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +""" + +import pytest + +from gremlin_python.structure.graph import ProviderDefinedType, ProviderDefinedTypeRegistry, provider_defined +from gremlin_python.structure.io.graphbinaryV4 import GraphBinaryWriter, GraphBinaryReader + + +class TestProviderDefinedType(object): + graphbinary_writer = GraphBinaryWriter() + graphbinary_reader = GraphBinaryReader() + + def test_empty_name_rejected(self): + with pytest.raises(ValueError): + ProviderDefinedType("", {"x": 1}) + + def test_none_name_rejected(self): + with pytest.raises(ValueError): + ProviderDefinedType(None, {"x": 1}) + + +class TestProviderDefinedTypeRegistry(object): + + def test_hydrate_simple(self): + registry = ProviderDefinedTypeRegistry() + registry.register("com.example.Point", lambda props: (props["x"], props["y"])) + pdt = ProviderDefinedType("com.example.Point", {"x": 1.0, "y": 2.0}) + result = registry.hydrate(pdt) + assert result == (1.0, 2.0) + + def test_hydrate_no_adapter_returns_raw(self): + registry = ProviderDefinedTypeRegistry() + pdt = ProviderDefinedType("com.example.Unknown", {"a": 1}) + result = registry.hydrate(pdt) + assert result is pdt + + def test_hydrate_adapter_throws_falls_back(self): + registry = ProviderDefinedTypeRegistry() + registry.register("com.example.Bad", lambda props: 1 / 0) + pdt = ProviderDefinedType("com.example.Bad", {"x": 1}) + result = registry.hydrate(pdt) + assert result is pdt + + def test_hydrate_nested(self): + from collections import namedtuple + Inner = namedtuple("Inner", ["val"]) + Outer = namedtuple("Outer", ["child", "count"]) + registry = ProviderDefinedTypeRegistry() + registry.register("com.example.Inner", lambda props: Inner(props["val"].upper())) + registry.register("com.example.Outer", lambda props: Outer(props["child"], props["count"])) + inner = ProviderDefinedType("com.example.Inner", {"val": "hello"}) + outer = ProviderDefinedType("com.example.Outer", {"child": inner, "count": 42}) + result = registry.hydrate(outer) + assert result == Outer(Inner("HELLO"), 42) + + def test_hydrate_non_pdt_passthrough(self): + registry = ProviderDefinedTypeRegistry() + assert registry.hydrate("plain string") == "plain string" + assert registry.hydrate(42) == 42 + + def test_dehydrate_simple(self): + from collections import namedtuple + Point = namedtuple("Point", ["x", "y"]) + registry = ProviderDefinedTypeRegistry() + registry.register("com.example.Point", + deserialize_fn=lambda props: Point(props["x"], props["y"]), + serialize_fn=lambda p: {"x": p.x, "y": p.y}, + target_class=Point) + adapter = registry.get_adapter_by_class(Point) + props = adapter['serialize'](Point(1.0, 2.0)) + assert props == {"x": 1.0, "y": 2.0} + + def test_dehydrate_no_adapter_returns_none(self): + registry = ProviderDefinedTypeRegistry() + assert registry.get_adapter_by_class(str) is None + + def test_dehydrate_no_serialize_fn_returns_none(self): + registry = ProviderDefinedTypeRegistry() + registry.register("com.example.Thing", deserialize_fn=lambda props: props, target_class=dict) + adapter = registry.get_adapter_by_class(dict) + assert adapter['serialize'] is None + + +class TestProviderDefinedTypeRegistryBuild(object): + + def test_build_returns_registry_with_no_entry_points(self): + registry = ProviderDefinedTypeRegistry.build() + assert isinstance(registry, ProviderDefinedTypeRegistry) + + def test_build_loads_entry_point(self): + from unittest.mock import patch, MagicMock + + mock_ep = MagicMock() + mock_ep.name = "mock_adapter" + mock_ep.load.return_value = lambda reg: reg.register("com.mock.Type", lambda props: props) + + with patch("importlib.metadata.entry_points") as mock_entry_points: + import sys + if sys.version_info >= (3, 10): + mock_entry_points.return_value = [mock_ep] + else: + mock_entry_points.return_value = {'tinkerpop.pdt': [mock_ep]} + + registry = ProviderDefinedTypeRegistry.build() + assert "com.mock.Type" in registry._adapters_by_name + + def test_build_handles_failing_entry_point(self): + from unittest.mock import patch, MagicMock + + mock_ep = MagicMock() + mock_ep.name = "bad_adapter" + mock_ep.load.side_effect = RuntimeError("boom") + + with patch("importlib.metadata.entry_points") as mock_entry_points: + import sys + if sys.version_info >= (3, 10): + mock_entry_points.return_value = [mock_ep] + else: + mock_entry_points.return_value = {'tinkerpop.pdt': [mock_ep]} + + registry = ProviderDefinedTypeRegistry.build() + assert isinstance(registry, ProviderDefinedTypeRegistry) + assert len(registry._adapters_by_name) == 0 + + +class TestReaderAutoHydration(object): + + def test_reader_auto_hydrates_with_registry(self): + registry = ProviderDefinedTypeRegistry() + registry.register("com.example.Point", lambda props: {"x": props["x"], "y": props["y"], "hydrated": True}) + writer = GraphBinaryWriter() + reader = GraphBinaryReader(pdt_registry=registry) + + pdt = ProviderDefinedType("com.example.Point", {"x": 1.0, "y": 2.0}) + result = reader.read_object(writer.write_object(pdt)) + assert result == {"x": 1.0, "y": 2.0, "hydrated": True} + + def test_reader_no_registry_returns_raw_pdt(self): + writer = GraphBinaryWriter() + reader = GraphBinaryReader() + + pdt = ProviderDefinedType("com.example.Unregistered", {"x": 1.0, "y": 2.0}) + result = reader.read_object(writer.write_object(pdt)) + assert isinstance(result, ProviderDefinedType) + assert result == pdt + + +class TestProviderDefinedDecorator(object): + + def test_decorator_sets_metadata_with_name(self): + @provider_defined(name="com.example.Point", included_fields=["x", "y"]) + class Point: + pass + + assert Point._pdt_name == "com.example.Point" + assert Point._pdt_included_fields == ["x", "y"] + assert Point._pdt_excluded_fields is None + + def test_decorator_defaults_to_class_name(self): + @provider_defined() + class MyType: + pass + + assert MyType._pdt_name == "MyType" + assert MyType._pdt_included_fields is None + assert MyType._pdt_excluded_fields is None + + def test_decorator_excluded_fields(self): + @provider_defined(excluded_fields=["internal"]) + class Foo: + pass + + assert Foo._pdt_excluded_fields == ["internal"] + + +class TestPdtRegistryWiring(object): + + def test_serializer_passes_registry_to_reader(self): + pytest.importorskip("aiohttp") + from gremlin_python.driver.serializer import GraphBinarySerializersV4 + registry = ProviderDefinedTypeRegistry() + s = GraphBinarySerializersV4(pdt_registry=registry) + assert s._graphbinary_reader.pdt_registry is registry + + def test_client_passes_registry_to_serializers(self): + pytest.importorskip("aiohttp") + from unittest.mock import patch + from gremlin_python.driver.client import Client + registry = ProviderDefinedTypeRegistry() + with patch.object(Client, '_fill_pool'): + c = Client("ws://localhost:8182/gremlin", "g", pdt_registry=registry) + assert c._request_serializer._graphbinary_reader.pdt_registry is registry + assert c._response_serializer._graphbinary_reader.pdt_registry is registry + + def test_driver_remote_connection_passes_registry(self): + pytest.importorskip("aiohttp") + from unittest.mock import patch + from gremlin_python.driver.client import Client + from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection + registry = ProviderDefinedTypeRegistry() + with patch.object(Client, '_fill_pool'): + drc = DriverRemoteConnection("ws://localhost:8182/gremlin", "g", pdt_registry=registry) + assert drc._client._response_serializer._graphbinary_reader.pdt_registry is registry diff --git a/gremlin-server/pom.xml b/gremlin-server/pom.xml index df4c5e6da6b..9b40be61ee3 100644 --- a/gremlin-server/pom.xml +++ b/gremlin-server/pom.xml @@ -173,6 +173,16 @@ limitations under the License. + + maven-jar-plugin + + + + test-jar + + + + org.apache.maven.plugins maven-surefire-plugin diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java index 14b185142d6..03fafbd3f18 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java @@ -30,10 +30,17 @@ import org.apache.tinkerpop.gremlin.driver.exception.NoHostAvailableException; import org.apache.tinkerpop.gremlin.driver.exception.ResponseException; import org.apache.tinkerpop.gremlin.driver.interceptor.PayloadSerializingInterceptor; +import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection; import org.apache.tinkerpop.gremlin.jsr223.ScriptFileGremlinPlugin; +import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; import org.apache.tinkerpop.gremlin.server.channel.HttpChannelizer; import org.apache.tinkerpop.gremlin.structure.Vertex; import org.apache.tinkerpop.gremlin.structure.io.Storage; +import org.apache.tinkerpop.gremlin.structure.io.binary.TypeSerializerRegistry; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefined; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeAdapter; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; import org.apache.tinkerpop.gremlin.structure.util.detached.DetachedVertex; import org.apache.tinkerpop.gremlin.util.ExceptionHelper; import org.apache.tinkerpop.gremlin.util.TimeUtil; @@ -68,6 +75,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource.traversal; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.IsInstanceOf.instanceOf; @@ -1264,4 +1272,156 @@ public void shouldReturnUuid() throws Exception { cluster.close(); } } + + @Test + public void shouldRoundTripRawPdtViaTraversal() throws Exception { + final Cluster cluster = TestClientFactory.build().create(); + try { + final GraphTraversalSource g = traversal().with(DriverRemoteConnection.using(cluster)); + final Map props = new HashMap<>(); + props.put("x", 1); + props.put("y", 2); + final ProviderDefinedType pdt = new ProviderDefinedType("TestPoint", props); + final Object result = g.inject(pdt).next(); + + assertTrue(result instanceof ProviderDefinedType); + final ProviderDefinedType r = (ProviderDefinedType) result; + assertEquals("TestPoint", r.getName()); + assertEquals(1, r.getProperties().get("x")); + assertEquals(2, r.getProperties().get("y")); + } finally { + cluster.close(); + } + } + + @Test + public void shouldRoundTripRegistryPdtViaTraversal() throws Exception { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(new TestPointAdapter()); + + final Cluster cluster = TestClientFactory.build() + .serializer(new GraphBinaryMessageSerializerV4(TypeSerializerRegistry.INSTANCE, registry)) + .create(); + try { + final DriverRemoteConnection connection = DriverRemoteConnection.using(cluster); + connection.setPdtRegistry(registry); + final GraphTraversalSource g = traversal().with(connection); + + final Object result = g.inject(new TestPoint(5, 10)).next(); + + assertTrue("Expected TestPoint but got: " + result.getClass().getName(), result instanceof TestPoint); + assertEquals(5, ((TestPoint) result).x); + assertEquals(10, ((TestPoint) result).y); + } finally { + cluster.close(); + } + } + + @Test + public void shouldRoundTripAnnotatedPdtViaTraversal() throws Exception { + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(TestAnnotatedPoint.class); + + final Cluster cluster = TestClientFactory.build() + .serializer(new GraphBinaryMessageSerializerV4(TypeSerializerRegistry.INSTANCE, registry)) + .create(); + try { + final DriverRemoteConnection connection = DriverRemoteConnection.using(cluster); + connection.setPdtRegistry(registry); + final GraphTraversalSource g = traversal().with(connection); + + final Object result = g.inject(new TestAnnotatedPoint(3, 7)).next(); + + assertTrue("Expected TestAnnotatedPoint but got: " + result.getClass().getName(), + result instanceof TestAnnotatedPoint); + assertEquals(3, ((TestAnnotatedPoint) result).x); + assertEquals(7, ((TestAnnotatedPoint) result).y); + } finally { + cluster.close(); + } + } + + @Test + public void shouldHydratePdtViaRegistryFromDriverResult() throws Exception { + final Cluster cluster = TestClientFactory.build().create(); + final Client client = cluster.connect(); + try { + final List results = client.submit( + "g.inject(new org.apache.tinkerpop.gremlin.server.pdt.Point(9, 11))", + groovyRequestOptions).all().get(); + + assertEquals(1, results.size()); + final ProviderDefinedType raw = (ProviderDefinedType) results.get(0).getObject(); + + final ProviderDefinedTypeRegistry registry = ProviderDefinedTypeRegistry.empty(); + registry.register(new TestPointAdapter()); + final Object hydrated = registry.hydrate(raw); + + assertTrue(hydrated instanceof TestPoint); + assertEquals(9, ((TestPoint) hydrated).x); + assertEquals(11, ((TestPoint) hydrated).y); + } finally { + cluster.close(); + } + } + + @Test + public void shouldStorePdtAsOriginalObjectInTinkerGraph() throws Exception { + final Cluster cluster = TestClientFactory.build().create(); + final Client client = cluster.connect(); + try { + client.submit( + "g.addV('location').property('point', new org.apache.tinkerpop.gremlin.server.pdt.Point(3, 4)).iterate()", + groovyRequestOptions).all().get(); + + final List results = client.submit( + "g.V().hasLabel('location').values('point')", + groovyRequestOptions).all().get(); + + assertEquals(1, results.size()); + // Point is @ProviderDefined, so GraphBinary serializes the stored object as a ProviderDefinedType + // on the wire. Without a client-side registry it is received as a raw ProviderDefinedType. + final Object value = results.get(0).getObject(); + assertTrue("Expected ProviderDefinedType but got: " + value.getClass().getName(), + value instanceof ProviderDefinedType); + final ProviderDefinedType pdt = (ProviderDefinedType) value; + assertEquals("Point", pdt.getName()); + assertEquals(3, pdt.getProperties().get("x")); + assertEquals(4, pdt.getProperties().get("y")); + } finally { + // cleanup + client.submit("g.V().hasLabel('location').drop().iterate()", groovyRequestOptions).all().get(); + cluster.close(); + } + } + + // --- PDT helper types --- + + static class TestPoint { + final int x, y; + TestPoint(final int x, final int y) { this.x = x; this.y = y; } + } + + static class TestPointAdapter implements ProviderDefinedTypeAdapter { + // TestPoint is the client-side representation of the server-side @ProviderDefined "Point" type, + // so the adapter's type name matches the server type name "Point". + @Override public String typeName() { return "Point"; } + @Override public Class targetClass() { return TestPoint.class; } + @Override public Map toProperties(final TestPoint obj) { + final Map m = new HashMap<>(); + m.put("x", obj.x); + m.put("y", obj.y); + return m; + } + @Override public TestPoint fromProperties(final Map props) { + return new TestPoint(((Number) props.get("x")).intValue(), ((Number) props.get("y")).intValue()); + } + } + + @ProviderDefined(name = "TestAnnotatedPoint") + static class TestAnnotatedPoint { + public int x, y; + public TestAnnotatedPoint() {} + TestAnnotatedPoint(final int x, final int y) { this.x = x; this.y = y; } + } } diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerSerializationIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerSerializationIntegrateTest.java index 300b1e1e745..28a81731e3b 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerSerializationIntegrateTest.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerSerializationIntegrateTest.java @@ -18,9 +18,17 @@ */ package org.apache.tinkerpop.gremlin.server; +import org.apache.http.Consts; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.util.EntityUtils; import org.apache.tinkerpop.gremlin.driver.Client; import org.apache.tinkerpop.gremlin.driver.Cluster; import org.apache.tinkerpop.gremlin.driver.RequestOptions; +import org.apache.tinkerpop.gremlin.driver.Result; import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection; import org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource; import org.apache.tinkerpop.gremlin.process.traversal.Path; @@ -29,10 +37,14 @@ import org.apache.tinkerpop.gremlin.structure.Property; import org.apache.tinkerpop.gremlin.structure.Vertex; import org.apache.tinkerpop.gremlin.structure.VertexProperty; +import org.apache.tinkerpop.gremlin.structure.io.graphson.GraphSONTokens; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType; import org.apache.tinkerpop.gremlin.util.Tokens; import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils; import org.apache.tinkerpop.gremlin.util.ser.AbstractMessageSerializer; import org.apache.tinkerpop.gremlin.util.ser.GraphBinaryMessageSerializerV4; +import org.apache.tinkerpop.shaded.jackson.databind.JsonNode; +import org.apache.tinkerpop.shaded.jackson.databind.ObjectMapper; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -42,6 +54,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.List; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.Is.is; @@ -257,6 +270,79 @@ private void assertEdgeWithProperties(final Edge edge) { assertEquals(0.5, edge.property("weight").value()); } + @Test + public void shouldRoundTripSimplePointPdt() throws Exception { + final List results = client.submit( + "g.inject(PDT(\"Point\", [\"x\":1, \"y\":2]))").all().get(); + + assertEquals(1, results.size()); + final ProviderDefinedType pdt = (ProviderDefinedType) results.get(0).getObject(); + assertEquals("Point", pdt.getName()); + assertEquals(1, pdt.getProperties().get("x")); + assertEquals(2, pdt.getProperties().get("y")); + } + + @Test + public void shouldRoundTripNestedPdt() throws Exception { + final List results = client.submit( + "g.inject(PDT(\"Person\", [\"name\":\"Alice\", \"age\":30, " + + "\"address\":PDT(\"Address\", [\"street\":\"123 Main St\", \"city\":\"Springfield\", \"zip\":\"12345\"])]))").all().get(); + + assertEquals(1, results.size()); + final ProviderDefinedType person = (ProviderDefinedType) results.get(0).getObject(); + assertEquals("Person", person.getName()); + assertEquals("Alice", person.getProperties().get("name")); + assertEquals(30, person.getProperties().get("age")); + + final ProviderDefinedType address = (ProviderDefinedType) person.getProperties().get("address"); + assertEquals("Address", address.getName()); + assertEquals("123 Main St", address.getProperties().get("street")); + assertEquals("Springfield", address.getProperties().get("city")); + } + + @Test + public void shouldRoundTripPdtInCollection() throws Exception { + final List results = client.submit( + "g.inject([PDT(\"Point\", [\"x\":1, \"y\":2]), PDT(\"Point\", [\"x\":3, \"y\":4])])").all().get(); + + assertEquals(1, results.size()); + final List list = (List) results.get(0).getObject(); + assertEquals(2, list.size()); + + final ProviderDefinedType p1 = (ProviderDefinedType) list.get(0); + assertEquals("Point", p1.getName()); + assertEquals(1, p1.getProperties().get("x")); + assertEquals(2, p1.getProperties().get("y")); + + final ProviderDefinedType p2 = (ProviderDefinedType) list.get(1); + assertEquals("Point", p2.getName()); + assertEquals(3, p2.getProperties().get("x")); + assertEquals(4, p2.getProperties().get("y")); + } + + @Test + public void shouldReturnPdtAsGraphSONCompositePdtInHttpResponse() throws Exception { + final CloseableHttpClient httpclient = HttpClients.createDefault(); + final HttpPost httppost = new HttpPost(TestClientFactory.createURLString()); + httppost.addHeader("Content-Type", "application/json"); + httppost.addHeader("Accept", "application/json"); + httppost.setEntity(new StringEntity( + "{\"gremlin\":\"g.inject(org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedType.from(new org.apache.tinkerpop.gremlin.server.pdt.Point(1, 2)))\",\"language\":\"gremlin-groovy\"}", + Consts.UTF_8)); + + try (final CloseableHttpResponse response = httpclient.execute(httppost)) { + assertEquals(200, response.getStatusLine().getStatusCode()); + final JsonNode root = new ObjectMapper().readTree(EntityUtils.toString(response.getEntity())); + final JsonNode pdtNode = root.get("result").get("data").get(GraphSONTokens.VALUEPROP).get(0); + + assertEquals("g:CompositePdt", pdtNode.get("@type").asText()); + final JsonNode value = pdtNode.get(GraphSONTokens.VALUEPROP); + assertEquals("Point", value.get("type").asText()); + assertEquals(1, value.get("fields").get("x").get(GraphSONTokens.VALUEPROP).intValue()); + assertEquals(2, value.get("fields").get("y").get(GraphSONTokens.VALUEPROP).intValue()); + } + } + private void assertPathElementsWithProperties(final Path p) { // expect a V-E-V path assertEquals(3, p.size()); diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/pdt/Address.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/pdt/Address.java new file mode 100644 index 00000000000..cc01cdd8a75 --- /dev/null +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/pdt/Address.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.server.pdt; + +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefined; + +@ProviderDefined(name = "Address") +public class Address { + public String street; + public String city; + public String zip; + + public Address(final String street, final String city, final String zip) { + this.street = street; + this.city = city; + this.zip = zip; + } +} diff --git a/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/types/sample/SamplePerson.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/pdt/Person.java similarity index 57% rename from gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/types/sample/SamplePerson.java rename to gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/pdt/Person.java index 2af646b6de0..96b878f950e 100644 --- a/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/types/sample/SamplePerson.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/pdt/Person.java @@ -16,31 +16,19 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.tinkerpop.gremlin.util.ser.binary.types.sample; +package org.apache.tinkerpop.gremlin.server.pdt; -import java.time.OffsetDateTime; -import java.util.Objects; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefined; -/** - * A sample custom data type containing few properties. - */ -public class SamplePerson { - private final String name; - private final OffsetDateTime birthDate; - - SamplePerson(final String name, final OffsetDateTime birthDate) { - Objects.requireNonNull(name); - Objects.requireNonNull(birthDate); +@ProviderDefined(name = "Person") +public class Person { + public String name; + public int age; + public Address address; + public Person(final String name, final int age, final Address address) { this.name = name; - this.birthDate = birthDate; - } - - public String getName() { - return name; - } - - public OffsetDateTime getBirthDate() { - return birthDate; + this.age = age; + this.address = address; } } diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/pdt/Point.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/pdt/Point.java new file mode 100644 index 00000000000..1bb8470061c --- /dev/null +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/pdt/Point.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.server.pdt; + +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefined; + +/** + * A simple test class annotated with {@link ProviderDefined} for PDT integration testing. + */ +@ProviderDefined(name = "Point") +public class Point { + public int x; + public int y; + + public Point(final int x, final int y) { + this.x = x; + this.y = y; + } +} diff --git a/gremlin-util/src/main/java/org/apache/tinkerpop/gremlin/util/ser/GraphBinaryMessageSerializerV4.java b/gremlin-util/src/main/java/org/apache/tinkerpop/gremlin/util/ser/GraphBinaryMessageSerializerV4.java index e9c7a97f82e..795712ff769 100644 --- a/gremlin-util/src/main/java/org/apache/tinkerpop/gremlin/util/ser/GraphBinaryMessageSerializerV4.java +++ b/gremlin-util/src/main/java/org/apache/tinkerpop/gremlin/util/ser/GraphBinaryMessageSerializerV4.java @@ -23,24 +23,20 @@ import io.netty.handler.codec.http.HttpResponseStatus; import org.apache.tinkerpop.gremlin.structure.Graph; import org.apache.tinkerpop.gremlin.structure.io.Buffer; -import org.apache.tinkerpop.gremlin.structure.io.IoRegistry; -import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryIo; import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryMapper; import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryReader; import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryWriter; import org.apache.tinkerpop.gremlin.structure.io.binary.Marker; import org.apache.tinkerpop.gremlin.structure.io.binary.TypeSerializerRegistry; -import org.apache.tinkerpop.gremlin.structure.io.binary.types.CustomTypeSerializer; +import org.apache.tinkerpop.gremlin.structure.io.pdt.ProviderDefinedTypeRegistry; import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import org.apache.tinkerpop.gremlin.util.message.ResponseMessage; import org.apache.tinkerpop.gremlin.util.message.ResponseStatus; import org.apache.tinkerpop.gremlin.util.ser.binary.RequestMessageSerializer; -import org.javatuples.Pair; import org.javatuples.Triplet; import java.io.IOException; import java.lang.reflect.Constructor; -import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Collections; import java.util.EnumSet; @@ -48,7 +44,6 @@ import java.util.Map; public class GraphBinaryMessageSerializerV4 extends AbstractMessageSerializer { - public static final String TOKEN_CUSTOM = "custom"; public static final String TOKEN_BUILDER = "builder"; private GraphBinaryReader reader; @@ -74,6 +69,14 @@ public GraphBinaryMessageSerializerV4(final TypeSerializerRegistry registry) { requestSerializer = new RequestMessageSerializer(); } + public GraphBinaryMessageSerializerV4(final TypeSerializerRegistry registry, final ProviderDefinedTypeRegistry pdtRegistry) { + reader = new GraphBinaryReader(registry, pdtRegistry); + writer = new GraphBinaryWriter(registry); + mapper = new GraphBinaryMapper(writer, reader); + + requestSerializer = new RequestMessageSerializer(); + } + public GraphBinaryMessageSerializerV4(final TypeSerializerRegistry.Builder builder) { this(builder.create()); } @@ -100,27 +103,6 @@ public void configure(final Map config, final Map builder = TypeSerializerRegistry.build(); } - final List classNameList = getListStringFromConfig(TOKEN_IO_REGISTRIES, config); - classNameList.forEach(className -> { - try { - final Class clazz = Class.forName(className); - try { - final Method instanceMethod = tryInstanceMethod(clazz); - final IoRegistry ioreg = (IoRegistry) instanceMethod.invoke(null); - final List> classSerializers = ioreg.find(GraphBinaryIo.class, CustomTypeSerializer.class); - for (Pair cs : classSerializers) { - builder.addCustomType(cs.getValue0(), cs.getValue1()); - } - } catch (Exception methodex) { - throw new IllegalStateException(String.format("Could not instantiate IoRegistry from an instance() method on %s", className), methodex); - } - } catch (Exception ex) { - throw new IllegalStateException(ex); - } - }); - - addCustomClasses(config, builder); - final TypeSerializerRegistry registry = builder.create(); reader = new GraphBinaryReader(registry); writer = new GraphBinaryWriter(registry); @@ -133,34 +115,6 @@ public String[] mimeTypesSupported() { return new String[] {MIME_TYPE}; } - private void addCustomClasses(final Map config, final TypeSerializerRegistry.Builder builder) { - final List classNameList = getListStringFromConfig(TOKEN_CUSTOM, config); - - classNameList.forEach(serializerDefinition -> { - final String className; - final String serializerName; - if (serializerDefinition.contains(";")) { - final String[] split = serializerDefinition.split(";"); - if (split.length != 2) - throw new IllegalStateException(String.format("Invalid format for serializer definition [%s] - expected ;", serializerDefinition)); - - className = split[0]; - serializerName = split[1]; - } else { - throw new IllegalStateException(String.format("Invalid format for serializer definition [%s] - expected ;", serializerDefinition)); - } - - try { - final Class clazz = Class.forName(className); - final Class serializerClazz = Class.forName(serializerName); - final CustomTypeSerializer serializer = (CustomTypeSerializer) serializerClazz.newInstance(); - builder.addCustomType(clazz, serializer); - } catch (Exception ex) { - throw new IllegalStateException("CustomTypeSerializer could not be instantiated", ex); - } - }); - } - @Override public ByteBuf serializeRequestAsBinary(RequestMessage requestMessage, ByteBufAllocator allocator) throws SerializationException { final ByteBuf buffer = allocator.buffer(); diff --git a/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/TypeSerializerRegistryTest.java b/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/TypeSerializerRegistryTest.java index aaba1987a91..c5618c4eb94 100644 --- a/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/TypeSerializerRegistryTest.java +++ b/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/TypeSerializerRegistryTest.java @@ -18,8 +18,6 @@ */ package org.apache.tinkerpop.gremlin.util.ser.binary; -import org.apache.tinkerpop.gremlin.util.ser.binary.types.sample.SamplePerson; -import org.apache.tinkerpop.gremlin.util.ser.binary.types.sample.SamplePersonSerializer; import org.apache.tinkerpop.gremlin.structure.io.binary.DataType; import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryReader; import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryWriter; @@ -93,18 +91,18 @@ public void shouldUseFallbackResolverWhenThereIsNoMatch() { String message = null; try { - registry.getSerializer(SamplePerson.class); + registry.getSerializer(StringBuilder.class); } catch (IOException ex) { message = ex.getMessage(); } - assertEquals("Serializer for type org.apache.tinkerpop.gremlin.util.ser.binary.types.sample.SamplePerson not found", message); + assertEquals("Serializer not found for type java.lang.StringBuilder. If this is a provider-defined type, annotate the class with @ProviderDefined.", message); assertEquals(1, called[0]); } @Test public void shouldUseFallbackResolverReturnValue() throws IOException { - TypeSerializer expected = new SamplePersonSerializer(); + TypeSerializer expected = new TestUUIDSerializer(); final int[] called = {0}; final TypeSerializerRegistry registry = TypeSerializerRegistry.build() .withFallbackResolver(t -> { @@ -112,7 +110,7 @@ public void shouldUseFallbackResolverReturnValue() throws IOException { return expected; }).create(); - TypeSerializer serializer = registry.getSerializer(SamplePerson.class); + TypeSerializer serializer = registry.getSerializer(StringBuilder.class); assertEquals(1, called[0]); assertSame(expected, serializer); } diff --git a/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/types/sample/SamplePersonSerializer.java b/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/types/sample/SamplePersonSerializer.java deleted file mode 100644 index bf6c6cae79f..00000000000 --- a/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/types/sample/SamplePersonSerializer.java +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.tinkerpop.gremlin.util.ser.binary.types.sample; - -import org.apache.tinkerpop.gremlin.util.ser.SerializationException; -import org.apache.tinkerpop.gremlin.structure.io.Buffer; -import org.apache.tinkerpop.gremlin.structure.io.binary.DataType; -import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryReader; -import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryWriter; -import org.apache.tinkerpop.gremlin.structure.io.binary.types.CustomTypeSerializer; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.time.OffsetDateTime; - -/** - * A sample custom type serializer. - */ -public final class SamplePersonSerializer implements CustomTypeSerializer { - private final byte[] typeInfoBuffer = new byte[] { 0, 0, 0, 0 }; - - @Override - public String getTypeName() { - return "sampleProvider.SamplePerson"; - } - - @Override - public DataType getDataType() { - return DataType.CUSTOM; - } - - @Override - public SamplePerson read(final Buffer buffer, final GraphBinaryReader context) throws IOException { - // {custom type info}, {value_flag} and {value} - // No custom_type_info - if (buffer.readInt() != 0) { - throw new SerializationException("{custom_type_info} should not be provided for this custom type"); - } - - return readValue(buffer, context, true); - } - - @Override - public SamplePerson readValue(final Buffer buffer, final GraphBinaryReader context, final boolean nullable) throws IOException { - if (nullable) { - final byte valueFlag = buffer.readByte(); - if ((valueFlag & 1) == 1) { - return null; - } - } - - // Read the byte length of the value bytes - final int valueLength = buffer.readInt(); - - if (valueLength <= 0) { - throw new SerializationException(String.format("Unexpected value length: %d", valueLength)); - } - - if (valueLength > buffer.readableBytes()) { - throw new SerializationException( - String.format("Not enough readable bytes: %d (expected %d)", valueLength, buffer.readableBytes())); - } - - final String name = context.readValue(buffer, String.class, false); - final OffsetDateTime birthDate = context.readValue(buffer, OffsetDateTime.class, false); - - return new SamplePerson(name, birthDate); - } - - @Override - public void write(final SamplePerson value, final Buffer buffer, final GraphBinaryWriter context) throws IOException { - // Write {custom type info}, {value_flag} and {value} - buffer.writeBytes(typeInfoBuffer); - - writeValue(value, buffer, context, true); - } - - @Override - public void writeValue(final SamplePerson value, final Buffer buffer, final GraphBinaryWriter context, final boolean nullable) throws IOException { - if (value == null) { - if (!nullable) { - throw new SerializationException("Unexpected null value when nullable is false"); - } - - context.writeValueFlagNull(buffer); - return; - } - - if (nullable) { - context.writeValueFlagNone(buffer); - } - - final String name = value.getName(); - - // value_length = name_byte_length + name_bytes + long - buffer.writeInt(4 + name.getBytes(StandardCharsets.UTF_8).length + 8); - - context.writeValue(name, buffer, false); - context.writeValue(value.getBirthDate(), buffer, false); - } -} diff --git a/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/types/sample/SamplePersonSerializerTest.java b/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/types/sample/SamplePersonSerializerTest.java deleted file mode 100644 index 83f64d2fcd1..00000000000 --- a/gremlin-util/src/test/java/org/apache/tinkerpop/gremlin/util/ser/binary/types/sample/SamplePersonSerializerTest.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.tinkerpop.gremlin.util.ser.binary.types.sample; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.handler.codec.http.HttpResponseStatus; -import org.apache.tinkerpop.gremlin.util.ser.NettyBufferFactory; -import org.apache.tinkerpop.gremlin.util.message.ResponseMessage; -import org.apache.tinkerpop.gremlin.util.ser.GraphBinaryMessageSerializerV4; -import org.apache.tinkerpop.gremlin.structure.io.AbstractIoRegistry; -import org.apache.tinkerpop.gremlin.structure.io.Buffer; -import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryIo; -import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryReader; -import org.apache.tinkerpop.gremlin.structure.io.binary.GraphBinaryWriter; -import org.apache.tinkerpop.gremlin.structure.io.binary.TypeSerializerRegistry; -import org.junit.Ignore; -import org.junit.Test; - -import java.io.IOException; -import java.time.LocalDateTime; -import java.time.OffsetDateTime; -import java.time.ZoneOffset; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import static org.apache.tinkerpop.gremlin.util.MockitoHamcrestMatcherAdapter.reflectionEquals; -import static org.apache.tinkerpop.gremlin.util.ser.AbstractMessageSerializer.TOKEN_IO_REGISTRIES; -import static org.apache.tinkerpop.gremlin.util.ser.GraphBinaryMessageSerializerV4.TOKEN_CUSTOM; -import static org.hamcrest.MatcherAssert.assertThat; - -public class SamplePersonSerializerTest { - - private static final ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - private static final NettyBufferFactory bufferFactory = new NettyBufferFactory(); - - @Test - public void shouldCustomSerializationWithPerson() throws IOException { - final GraphBinaryMessageSerializerV4 serializer = new GraphBinaryMessageSerializerV4( - TypeSerializerRegistry.build().addCustomType(SamplePerson.class, new SamplePersonSerializer()).create()); - assertPerson(serializer); - } - - @Test - public void shouldSerializePersonViaIoRegistry() throws IOException { - final GraphBinaryMessageSerializerV4 serializer = new GraphBinaryMessageSerializerV4(); - final Map config = new HashMap<>(); - config.put(TOKEN_IO_REGISTRIES, Collections.singletonList(CustomIoRegistry.class.getName())); - serializer.configure(config, Collections.emptyMap()); - - assertPerson(serializer); - } - - @Test - public void shouldSerializePersonViaCustom() throws IOException { - final GraphBinaryMessageSerializerV4 serializer = new GraphBinaryMessageSerializerV4(); - final Map config = new HashMap<>(); - config.put(TOKEN_CUSTOM, Collections.singletonList(String.format("%s;%s", - SamplePerson.class.getCanonicalName(), SamplePersonSerializer.class.getCanonicalName()))); - serializer.configure(config, Collections.emptyMap()); - - assertPerson(serializer); - } - - @Test - public void readValueAndWriteValueShouldBeSymmetric() throws IOException { - final TypeSerializerRegistry registry = TypeSerializerRegistry.build() - .addCustomType(SamplePerson.class, new SamplePersonSerializer()).create(); - final GraphBinaryReader reader = new GraphBinaryReader(registry); - final GraphBinaryWriter writer = new GraphBinaryWriter(registry); - - final SamplePerson person = new SamplePerson("Matias", - OffsetDateTime.of(LocalDateTime.of(2005, 8, 5, 1, 0), ZoneOffset.UTC)); - - for (boolean nullable: new boolean[] { true, false }) { - final Buffer buffer = bufferFactory.create(allocator.buffer()); - writer.writeValue(person, buffer, nullable); - final SamplePerson actual = reader.readValue(buffer, SamplePerson.class, nullable); - - assertThat(actual, reflectionEquals(person)); - buffer.release(); - } - } - - private void assertPerson(final GraphBinaryMessageSerializerV4 serializer) throws IOException { - final OffsetDateTime birthDate = OffsetDateTime.of(LocalDateTime.of(2010, 4, 29, 5, 30), ZoneOffset.UTC); - final SamplePerson person = new SamplePerson("Olivia", birthDate); - - final ByteBuf serialized = serializer.serializeResponseAsBinary( - ResponseMessage.build().result(Collections.singletonList(person)).code(HttpResponseStatus.OK).create(), allocator); - - final ResponseMessage deserialized = serializer.deserializeBinaryResponse(serialized); - - final SamplePerson actual = (SamplePerson) deserialized.getResult().getData().get(0); - assertThat(actual, reflectionEquals(person)); - } - - public static class CustomIoRegistry extends AbstractIoRegistry { - private static final CustomIoRegistry ioreg = new CustomIoRegistry(); - - private CustomIoRegistry() { - register(GraphBinaryIo.class, SamplePerson.class, new SamplePersonSerializer()); - } - - public static CustomIoRegistry instance() { - return ioreg; - } - } -}