diff --git a/rosidl_generator_py/CMakeLists.txt b/rosidl_generator_py/CMakeLists.txt index e5ff834f..6b384c86 100644 --- a/rosidl_generator_py/CMakeLists.txt +++ b/rosidl_generator_py/CMakeLists.txt @@ -69,6 +69,7 @@ if(BUILD_TESTING) ) ament_add_pytest_test(test_cli_extension test/test_cli_extension.py) + ament_add_pytest_test(test_template_imports_py test/test_template_imports.py) ament_add_pytest_test(test_property_py test/test_property.py APPEND_ENV "PYTHONPATH=${pythonpath}" diff --git a/rosidl_generator_py/resource/_msg.py.em b/rosidl_generator_py/resource/_msg.py.em index a8682df0..e093a3e6 100644 --- a/rosidl_generator_py/resource/_msg.py.em +++ b/rosidl_generator_py/resource/_msg.py.em @@ -30,6 +30,27 @@ from rosidl_parser.definition import NamespacedType from rosidl_parser.definition import SIGNED_INTEGER_TYPES from rosidl_parser.definition import UnboundedSequence from rosidl_parser.definition import UNSIGNED_INTEGER_TYPES + + +def get_importable_namespaced_type( + type_: NamespacedType, + action_goal_suffix: str = ACTION_GOAL_SUFFIX, + action_result_suffix: str = ACTION_RESULT_SUFFIX, + action_feedback_suffix: str = ACTION_FEEDBACK_SUFFIX, + camel_to_snake=convert_camel_case_to_lower_case_underscore, +) -> tuple[str, str]: + joined_type_namespaces = '.'.join(type_.namespaces) + if ( + type_.name.endswith(action_goal_suffix) or + type_.name.endswith(action_result_suffix) or + type_.name.endswith(action_feedback_suffix) + ): + action_name, _ = type_.name.rsplit('_', 1) + lower_case_name = camel_to_snake(action_name) + module_name = f'{joined_type_namespaces}._{lower_case_name}' + else: + module_name = joined_type_namespaces + return module_name, f'{module_name}.{type_.name}' }@ @{ import_type_checking = False @@ -182,22 +203,13 @@ for member in message.structure.members: type_.name.endswith(SERVICE_REQUEST_MESSAGE_SUFFIX) ): continue - if ( - type_.name.endswith(ACTION_GOAL_SUFFIX) or - type_.name.endswith(ACTION_RESULT_SUFFIX) or - type_.name.endswith(ACTION_FEEDBACK_SUFFIX) - ): - action_name, suffix = type_.name.rsplit('_', 1) - typename = (*type_.namespaces, action_name, action_name + '.' + suffix) - else: - typename = (*type_.namespaces, type_.name, type_.name) - importable_typesupports.add(typename) + importable_typesupports.add(get_importable_namespaced_type(type_)) }@ -@[for typename in sorted(importable_typesupports)]@ +@[for module_name, type_name in sorted(importable_typesupports)]@ - from @('.'.join(typename[:-2])) import @(typename[-2]) - if @(typename[-1])._TYPE_SUPPORT is None: - @(typename[-1]).__import_type_support__() + import @(module_name) + if @(type_name)._TYPE_SUPPORT is None: + @(type_name).__import_type_support__() @[end for]@ @@classmethod @@ -370,15 +382,10 @@ if isinstance(type_, AbstractNestedType): self.@(member.name) = @(member.name) if @(member.name) is not None else @(message.structure.namespaced_type.name).@(member.name.upper())__DEFAULT @[ else]@ @[ if isinstance(type_, NamespacedType) and not isinstance(member.type, AbstractSequence)]@ -@[ if ( - type_.name.endswith(ACTION_GOAL_SUFFIX) or - type_.name.endswith(ACTION_RESULT_SUFFIX) or - type_.name.endswith(ACTION_FEEDBACK_SUFFIX) - )]@ - from @('.'.join(type_.namespaces))._@(convert_camel_case_to_lower_case_underscore(type_.name.rsplit('_', 1)[0])) import @(type_.name) -@[ else]@ - from @('.'.join(type_.namespaces)) import @(type_.name) -@[ end if]@ +@{ +module_name, type_name = get_importable_namespaced_type(type_) +}@ + import @(module_name) @[ end if]@ @[ if isinstance(member.type, Array)]@ @[ if isinstance(type_, BasicType) and type_.typename == 'octet']@ @@ -392,7 +399,8 @@ if isinstance(type_, AbstractNestedType): else: self.@(member.name) = @(member.name) @[ else]@ - self.@(member.name) = @(member.name) if @(member.name) is not None else [@(get_python_type(type_))() for x in range(@(member.type.size))] +@{default_type = type_name if isinstance(type_, NamespacedType) else get_python_type(type_)}@ + self.@(member.name) = @(member.name) if @(member.name) is not None else [@(default_type)() for x in range(@(member.type.size))] @[ end if]@ @[ end if]@ @[ elif isinstance(member.type, AbstractSequence)]@ @@ -405,6 +413,8 @@ if isinstance(type_, AbstractNestedType): self.@(member.name) = @(member.name) if @(member.name) is not None else bytes([0]) @[ elif isinstance(type_, BasicType) and type_.typename in CHARACTER_TYPES]@ self.@(member.name) = @(member.name) if @(member.name) is not None else chr(0) +@[ elif isinstance(type_, NamespacedType)]@ + self.@(member.name) = @(member.name) if @(member.name) is not None else @(type_name)() @[ else]@ self.@(member.name) = @(member.name) if @(member.name) is not None else @(get_python_type(type_))() @[ end if]@ @@ -504,15 +514,10 @@ if isinstance(member.type, (Array, AbstractSequence)): @[ end if]@ @[ end if]@ @[ if isinstance(type_, NamespacedType)]@ -@[ if ( - type_.name.endswith(ACTION_GOAL_SUFFIX) or - type_.name.endswith(ACTION_RESULT_SUFFIX) or - type_.name.endswith(ACTION_FEEDBACK_SUFFIX) - )]@ - from @('.'.join(type_.namespaces))._@(convert_camel_case_to_lower_case_underscore(type_.name.rsplit('_', 1)[0])) import @(type_.name) -@[ else]@ - from @('.'.join(type_.namespaces)) import @(type_.name) -@[ end if]@ +@{ +module_name, type_name = get_importable_namespaced_type(type_) +}@ + import @(module_name) @[ end if]@ @{ diff --git a/rosidl_generator_py/resource/_msg_check_fields.py.em b/rosidl_generator_py/resource/_msg_check_fields.py.em index 50ad4128..ece3b37f 100644 --- a/rosidl_generator_py/resource/_msg_check_fields.py.em +++ b/rosidl_generator_py/resource/_msg_check_fields.py.em @@ -3,6 +3,9 @@ from rosidl_parser.definition import AbstractGenericString from rosidl_parser.definition import AbstractNestedType from rosidl_parser.definition import AbstractSequence from rosidl_parser.definition import Array +from rosidl_parser.definition import ACTION_FEEDBACK_SUFFIX +from rosidl_parser.definition import ACTION_GOAL_SUFFIX +from rosidl_parser.definition import ACTION_RESULT_SUFFIX from rosidl_parser.definition import BasicType from rosidl_parser.definition import BOOLEAN_TYPE from rosidl_parser.definition import INTEGER_TYPES @@ -11,8 +14,35 @@ from rosidl_parser.definition import FLOATING_POINT_TYPES from rosidl_parser.definition import SIGNED_INTEGER_TYPES from rosidl_parser.definition import UNSIGNED_INTEGER_TYPES from rosidl_parser.definition import NamespacedType +from rosidl_pycommon import convert_camel_case_to_lower_case_underscore from rosidl_generator_py.generate_py_impl import get_python_type from rosidl_generator_py.generate_py_impl import SPECIAL_NESTED_BASIC_TYPES + + +def get_importable_namespaced_type( + type_: NamespacedType, + action_goal_suffix: str = ACTION_GOAL_SUFFIX, + action_result_suffix: str = ACTION_RESULT_SUFFIX, + action_feedback_suffix: str = ACTION_FEEDBACK_SUFFIX, + camel_to_snake=convert_camel_case_to_lower_case_underscore, +) -> tuple[str, str]: + joined_type_namespaces = '.'.join(type_.namespaces) + if ( + type_.name.endswith(action_goal_suffix) or + type_.name.endswith(action_result_suffix) or + type_.name.endswith(action_feedback_suffix) + ): + action_name, _ = type_.name.rsplit('_', 1) + lower_case_name = camel_to_snake(action_name) + module_name = f'{joined_type_namespaces}._{lower_case_name}' + else: + module_name = joined_type_namespaces + return module_name, f'{module_name}.{type_.name}' +}@ +@{ +python_type = get_python_type(type_) +if isinstance(type_, NamespacedType): + _, python_type = get_importable_namespaced_type(type_) }@ if self._check_fields: @[ if isinstance(member.type, AbstractNestedType)]@ @@ -61,8 +91,8 @@ from rosidl_generator_py.generate_py_impl import SPECIAL_NESTED_BASIC_TYPES @{assert_msg_suffixes.insert(1, 'with length %d' % member.type.size)}@ @[ end if]@ @[ end if]@ - all(isinstance(v, @(get_python_type(type_))) for v in value) and -@{assert_msg_suffixes.append("and each value of type '%s'" % get_python_type(type_))}@ + all(isinstance(v, @(python_type)) for v in value) and +@{assert_msg_suffixes.append("and each value of type '%s'" % python_type)}@ @[ if isinstance(type_, BasicType) and type_.typename in SIGNED_INTEGER_TYPES]@ @{ nbits = int(type_.typename[3:]) @@ -106,7 +136,7 @@ bound = 1.7976931348623157e+308 "The '@(member.name)' field must be string value " \ 'not longer than @(type_.maximum_size)' @[ elif isinstance(type_, NamespacedType)]@ - isinstance(value, @(type_.name)), \ + isinstance(value, @(python_type)), \ "The '@(member.name)' field must be a sub message of type '@(type_.name)'" @[ elif isinstance(type_, BasicType) and type_.typename == 'octet']@ (isinstance(value, (bytes, bytearray, memoryview)) and diff --git a/rosidl_generator_py/test/test_template_imports.py b/rosidl_generator_py/test/test_template_imports.py new file mode 100644 index 00000000..23ebeac7 --- /dev/null +++ b/rosidl_generator_py/test/test_template_imports.py @@ -0,0 +1,123 @@ +# Copyright 2026 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import pathlib +import sys +from types import ModuleType + +from pytest import CaptureFixture +from rosidl_parser.definition import Array +from rosidl_parser.definition import IdlContent +from rosidl_parser.definition import Member +from rosidl_parser.definition import Message +from rosidl_parser.definition import NamespacedType +from rosidl_parser.definition import Structure +from rosidl_parser.definition import UnboundedSequence +from rosidl_pycommon import convert_camel_case_to_lower_case_underscore +from rosidl_pycommon import expand_template + +PACKAGE_NAME = 'rosidl_generator_py' +RESOURCE_DIR = pathlib.Path(__file__).parents[1] / 'resource' + + +try: + import rpyutils # noqa: F401 +except ImportError: + rpyutils = ModuleType('rpyutils') + + def _add_dll_directories_from_env( + _: str, + ) -> contextlib.AbstractContextManager[None]: + return contextlib.nullcontext() + + rpyutils.add_dll_directories_from_env = _add_dll_directories_from_env + sys.modules['rpyutils'] = rpyutils + + +def _render_message( + tmp_path: pathlib.Path, + capsys: CaptureFixture[str], + message_name: str, + members: list[Member], +) -> str: + content = IdlContent() + content.elements.append(Message( + Structure(NamespacedType([PACKAGE_NAME, 'msg'], message_name), members) + )) + + output_filename = ( + f'_{convert_camel_case_to_lower_case_underscore(message_name)}.py' + ) + output_path = tmp_path / output_filename + with capsys.disabled(): + expand_template( + str(RESOURCE_DIR / '_idl.py.em'), + { + 'package_name': PACKAGE_NAME, + 'interface_path': pathlib.Path('msg') / f'{message_name}.msg', + 'content': content, + }, + str(output_path), + ) + return output_path.read_text(encoding='utf-8') + + +def test_namespaced_field_imports_are_absolute( + tmp_path: pathlib.Path, + capsys: CaptureFixture[str], +) -> None: + duration_type = NamespacedType(['builtin_interfaces', 'msg'], 'Duration') + source = _render_message( + tmp_path, + capsys, + 'Duration', + [Member(duration_type, 'data')], + ) + + assert 'from builtin_interfaces.msg import Duration' not in source + assert 'import builtin_interfaces.msg' in source + assert ( + 'if builtin_interfaces.msg.Duration._TYPE_SUPPORT is None:' + ) in source + assert ( + 'self.data = data if data is not None else ' + 'builtin_interfaces.msg.Duration()' + ) in source + assert 'isinstance(value, builtin_interfaces.msg.Duration)' in source + + +def test_namespaced_array_and_sequence_imports_are_absolute( + tmp_path: pathlib.Path, + capsys: CaptureFixture[str], +) -> None: + duration_type = NamespacedType(['builtin_interfaces', 'msg'], 'Duration') + source = _render_message( + tmp_path, + capsys, + 'DurationArraySequence', + [ + Member(Array(duration_type, 2), 'array_data'), + Member(UnboundedSequence(duration_type), 'sequence_data'), + ], + ) + + assert 'from builtin_interfaces.msg import Duration' not in source + assert 'import builtin_interfaces.msg' in source + assert ( + '[builtin_interfaces.msg.Duration() for x in range(2)]' + ) in source + assert ( + 'all(isinstance(v, builtin_interfaces.msg.Duration) for v in value)' + ) in source