Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rosidl_generator_py/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ if(BUILD_TESTING)
)

ament_add_pytest_test(test_cli_extension test/test_cli_extension.py)
ament_add_pytest_test(test_template_imports_py test/test_template_imports.py)

ament_add_pytest_test(test_property_py test/test_property.py
APPEND_ENV "PYTHONPATH=${pythonpath}"
Expand Down
71 changes: 38 additions & 33 deletions rosidl_generator_py/resource/_msg.py.em
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ from rosidl_parser.definition import NamespacedType
from rosidl_parser.definition import SIGNED_INTEGER_TYPES
from rosidl_parser.definition import UnboundedSequence
from rosidl_parser.definition import UNSIGNED_INTEGER_TYPES


def get_importable_namespaced_type(
type_: NamespacedType,
action_goal_suffix: str = ACTION_GOAL_SUFFIX,
action_result_suffix: str = ACTION_RESULT_SUFFIX,
action_feedback_suffix: str = ACTION_FEEDBACK_SUFFIX,
camel_to_snake=convert_camel_case_to_lower_case_underscore,
) -> tuple[str, str]:
joined_type_namespaces = '.'.join(type_.namespaces)
if (
type_.name.endswith(action_goal_suffix) or
type_.name.endswith(action_result_suffix) or
type_.name.endswith(action_feedback_suffix)
):
action_name, _ = type_.name.rsplit('_', 1)
lower_case_name = camel_to_snake(action_name)
module_name = f'{joined_type_namespaces}._{lower_case_name}'
else:
module_name = joined_type_namespaces
return module_name, f'{module_name}.{type_.name}'
}@
@{
import_type_checking = False
Expand Down Expand Up @@ -182,22 +203,13 @@ for member in message.structure.members:
type_.name.endswith(SERVICE_REQUEST_MESSAGE_SUFFIX)
):
continue
if (
type_.name.endswith(ACTION_GOAL_SUFFIX) or
type_.name.endswith(ACTION_RESULT_SUFFIX) or
type_.name.endswith(ACTION_FEEDBACK_SUFFIX)
):
action_name, suffix = type_.name.rsplit('_', 1)
typename = (*type_.namespaces, action_name, action_name + '.' + suffix)
else:
typename = (*type_.namespaces, type_.name, type_.name)
importable_typesupports.add(typename)
importable_typesupports.add(get_importable_namespaced_type(type_))
}@
@[for typename in sorted(importable_typesupports)]@
@[for module_name, type_name in sorted(importable_typesupports)]@

from @('.'.join(typename[:-2])) import @(typename[-2])
if @(typename[-1])._TYPE_SUPPORT is None:
@(typename[-1]).__import_type_support__()
import @(module_name)
if @(type_name)._TYPE_SUPPORT is None:
@(type_name).__import_type_support__()
@[end for]@

@@classmethod
Expand Down Expand Up @@ -370,15 +382,10 @@ if isinstance(type_, AbstractNestedType):
self.@(member.name) = @(member.name) if @(member.name) is not None else @(message.structure.namespaced_type.name).@(member.name.upper())__DEFAULT
@[ else]@
@[ if isinstance(type_, NamespacedType) and not isinstance(member.type, AbstractSequence)]@
@[ if (
type_.name.endswith(ACTION_GOAL_SUFFIX) or
type_.name.endswith(ACTION_RESULT_SUFFIX) or
type_.name.endswith(ACTION_FEEDBACK_SUFFIX)
)]@
from @('.'.join(type_.namespaces))._@(convert_camel_case_to_lower_case_underscore(type_.name.rsplit('_', 1)[0])) import @(type_.name)
@[ else]@
from @('.'.join(type_.namespaces)) import @(type_.name)
@[ end if]@
@{
module_name, type_name = get_importable_namespaced_type(type_)
}@
import @(module_name)
@[ end if]@
@[ if isinstance(member.type, Array)]@
@[ if isinstance(type_, BasicType) and type_.typename == 'octet']@
Expand All @@ -392,7 +399,8 @@ if isinstance(type_, AbstractNestedType):
else:
self.@(member.name) = @(member.name)
@[ else]@
self.@(member.name) = @(member.name) if @(member.name) is not None else [@(get_python_type(type_))() for x in range(@(member.type.size))]
@{default_type = type_name if isinstance(type_, NamespacedType) else get_python_type(type_)}@
self.@(member.name) = @(member.name) if @(member.name) is not None else [@(default_type)() for x in range(@(member.type.size))]
@[ end if]@
@[ end if]@
@[ elif isinstance(member.type, AbstractSequence)]@
Expand All @@ -405,6 +413,8 @@ if isinstance(type_, AbstractNestedType):
self.@(member.name) = @(member.name) if @(member.name) is not None else bytes([0])
@[ elif isinstance(type_, BasicType) and type_.typename in CHARACTER_TYPES]@
self.@(member.name) = @(member.name) if @(member.name) is not None else chr(0)
@[ elif isinstance(type_, NamespacedType)]@
self.@(member.name) = @(member.name) if @(member.name) is not None else @(type_name)()
@[ else]@
self.@(member.name) = @(member.name) if @(member.name) is not None else @(get_python_type(type_))()
@[ end if]@
Expand Down Expand Up @@ -504,15 +514,10 @@ if isinstance(member.type, (Array, AbstractSequence)):
@[ end if]@
@[ end if]@
@[ if isinstance(type_, NamespacedType)]@
@[ if (
type_.name.endswith(ACTION_GOAL_SUFFIX) or
type_.name.endswith(ACTION_RESULT_SUFFIX) or
type_.name.endswith(ACTION_FEEDBACK_SUFFIX)
)]@
from @('.'.join(type_.namespaces))._@(convert_camel_case_to_lower_case_underscore(type_.name.rsplit('_', 1)[0])) import @(type_.name)
@[ else]@
from @('.'.join(type_.namespaces)) import @(type_.name)
@[ end if]@
@{
module_name, type_name = get_importable_namespaced_type(type_)
}@
import @(module_name)
@[ end if]@

@{
Expand Down
36 changes: 33 additions & 3 deletions rosidl_generator_py/resource/_msg_check_fields.py.em
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ from rosidl_parser.definition import AbstractGenericString
from rosidl_parser.definition import AbstractNestedType
from rosidl_parser.definition import AbstractSequence
from rosidl_parser.definition import Array
from rosidl_parser.definition import ACTION_FEEDBACK_SUFFIX
from rosidl_parser.definition import ACTION_GOAL_SUFFIX
from rosidl_parser.definition import ACTION_RESULT_SUFFIX
from rosidl_parser.definition import BasicType
from rosidl_parser.definition import BOOLEAN_TYPE
from rosidl_parser.definition import INTEGER_TYPES
Expand All @@ -11,8 +14,35 @@ from rosidl_parser.definition import FLOATING_POINT_TYPES
from rosidl_parser.definition import SIGNED_INTEGER_TYPES
from rosidl_parser.definition import UNSIGNED_INTEGER_TYPES
from rosidl_parser.definition import NamespacedType
from rosidl_pycommon import convert_camel_case_to_lower_case_underscore
from rosidl_generator_py.generate_py_impl import get_python_type
from rosidl_generator_py.generate_py_impl import SPECIAL_NESTED_BASIC_TYPES


def get_importable_namespaced_type(
type_: NamespacedType,
action_goal_suffix: str = ACTION_GOAL_SUFFIX,
action_result_suffix: str = ACTION_RESULT_SUFFIX,
action_feedback_suffix: str = ACTION_FEEDBACK_SUFFIX,
camel_to_snake=convert_camel_case_to_lower_case_underscore,
) -> tuple[str, str]:
joined_type_namespaces = '.'.join(type_.namespaces)
if (
type_.name.endswith(action_goal_suffix) or
type_.name.endswith(action_result_suffix) or
type_.name.endswith(action_feedback_suffix)
):
action_name, _ = type_.name.rsplit('_', 1)
lower_case_name = camel_to_snake(action_name)
module_name = f'{joined_type_namespaces}._{lower_case_name}'
else:
module_name = joined_type_namespaces
return module_name, f'{module_name}.{type_.name}'
}@
@{
python_type = get_python_type(type_)
if isinstance(type_, NamespacedType):
_, python_type = get_importable_namespaced_type(type_)
}@
if self._check_fields:
@[ if isinstance(member.type, AbstractNestedType)]@
Expand Down Expand Up @@ -61,8 +91,8 @@ from rosidl_generator_py.generate_py_impl import SPECIAL_NESTED_BASIC_TYPES
@{assert_msg_suffixes.insert(1, 'with length %d' % member.type.size)}@
@[ end if]@
@[ end if]@
all(isinstance(v, @(get_python_type(type_))) for v in value) and
@{assert_msg_suffixes.append("and each value of type '%s'" % get_python_type(type_))}@
all(isinstance(v, @(python_type)) for v in value) and
@{assert_msg_suffixes.append("and each value of type '%s'" % python_type)}@
@[ if isinstance(type_, BasicType) and type_.typename in SIGNED_INTEGER_TYPES]@
@{
nbits = int(type_.typename[3:])
Expand Down Expand Up @@ -106,7 +136,7 @@ bound = 1.7976931348623157e+308
"The '@(member.name)' field must be string value " \
'not longer than @(type_.maximum_size)'
@[ elif isinstance(type_, NamespacedType)]@
isinstance(value, @(type_.name)), \
isinstance(value, @(python_type)), \
"The '@(member.name)' field must be a sub message of type '@(type_.name)'"
@[ elif isinstance(type_, BasicType) and type_.typename == 'octet']@
(isinstance(value, (bytes, bytearray, memoryview)) and
Expand Down
123 changes: 123 additions & 0 deletions rosidl_generator_py/test/test_template_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2026 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import pathlib
import sys
from types import ModuleType

from pytest import CaptureFixture
from rosidl_parser.definition import Array
from rosidl_parser.definition import IdlContent
from rosidl_parser.definition import Member
from rosidl_parser.definition import Message
from rosidl_parser.definition import NamespacedType
from rosidl_parser.definition import Structure
from rosidl_parser.definition import UnboundedSequence
from rosidl_pycommon import convert_camel_case_to_lower_case_underscore
from rosidl_pycommon import expand_template

PACKAGE_NAME = 'rosidl_generator_py'
RESOURCE_DIR = pathlib.Path(__file__).parents[1] / 'resource'


try:
import rpyutils # noqa: F401
except ImportError:
rpyutils = ModuleType('rpyutils')

def _add_dll_directories_from_env(
_: str,
) -> contextlib.AbstractContextManager[None]:
return contextlib.nullcontext()

rpyutils.add_dll_directories_from_env = _add_dll_directories_from_env
sys.modules['rpyutils'] = rpyutils


def _render_message(
tmp_path: pathlib.Path,
capsys: CaptureFixture[str],
message_name: str,
members: list[Member],
) -> str:
content = IdlContent()
content.elements.append(Message(
Structure(NamespacedType([PACKAGE_NAME, 'msg'], message_name), members)
))

output_filename = (
f'_{convert_camel_case_to_lower_case_underscore(message_name)}.py'
)
output_path = tmp_path / output_filename
with capsys.disabled():
expand_template(
str(RESOURCE_DIR / '_idl.py.em'),
{
'package_name': PACKAGE_NAME,
'interface_path': pathlib.Path('msg') / f'{message_name}.msg',
'content': content,
},
str(output_path),
)
return output_path.read_text(encoding='utf-8')


def test_namespaced_field_imports_are_absolute(
tmp_path: pathlib.Path,
capsys: CaptureFixture[str],
) -> None:
duration_type = NamespacedType(['builtin_interfaces', 'msg'], 'Duration')
source = _render_message(
tmp_path,
capsys,
'Duration',
[Member(duration_type, 'data')],
)

assert 'from builtin_interfaces.msg import Duration' not in source
assert 'import builtin_interfaces.msg' in source
assert (
'if builtin_interfaces.msg.Duration._TYPE_SUPPORT is None:'
) in source
assert (
'self.data = data if data is not None else '
'builtin_interfaces.msg.Duration()'
) in source
assert 'isinstance(value, builtin_interfaces.msg.Duration)' in source


def test_namespaced_array_and_sequence_imports_are_absolute(
tmp_path: pathlib.Path,
capsys: CaptureFixture[str],
) -> None:
duration_type = NamespacedType(['builtin_interfaces', 'msg'], 'Duration')
source = _render_message(
tmp_path,
capsys,
'DurationArraySequence',
[
Member(Array(duration_type, 2), 'array_data'),
Member(UnboundedSequence(duration_type), 'sequence_data'),
],
)

assert 'from builtin_interfaces.msg import Duration' not in source
assert 'import builtin_interfaces.msg' in source
assert (
'[builtin_interfaces.msg.Duration() for x in range(2)]'
) in source
assert (
'all(isinstance(v, builtin_interfaces.msg.Duration) for v in value)'
) in source