diff --git a/LICENSE b/LICENSE index b4e3f9a..2730aba 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2025 MrMaydo +Copyright (c) 2025 Zbigniew Brzezicki Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/src/header_generator.py b/src/header_generator.py new file mode 100644 index 0000000..416b478 --- /dev/null +++ b/src/header_generator.py @@ -0,0 +1,8 @@ +import re + + +def set_package(package: str) -> str: + PACKAGE_REGEX = r"^(?:[a-z_][a-z0-9_]*)(?:\.(?:[a-z_][a-z0-9_]*))*$" + if not re.match(PACKAGE_REGEX, package): + raise ValueError(f"Invalid package: '{package}'") + return f"package {package}" diff --git a/src/method_generator.py b/src/method_generator.py index 9073342..6e68abb 100644 --- a/src/method_generator.py +++ b/src/method_generator.py @@ -25,6 +25,7 @@ indent_lvl1 = " " indent_lvl2 = indent_lvl1 * 2 indent_lvl3 = indent_lvl1 * 3 +return_indent = " " @dataclass @@ -75,7 +76,7 @@ def generate_getter(field: Field) -> str: _validate_java_identifier(field_name) - getter_name = "get" + field_name[0].upper() + field_name[1:] + getter_name = _get_getter_name(field) getter = [ "", @@ -94,7 +95,7 @@ def generate_setter(field: Field) -> str: _validate_java_identifier(field_name) - setter_name = "set" + field_name[0].upper() + field_name[1:] + setter_name = _get_setter_name(field_name) setter = [ "", @@ -106,26 +107,76 @@ def generate_setter(field: Field) -> str: return setter +def generate_equals(class_name: str, fields: List[Field]) -> str: + equals = [ + "", + f"{indent_lvl1}@Override", + f"{indent_lvl1}public boolean equals(Object obj) {{", + + f"{indent_lvl2}if (this == obj)", + f"{indent_lvl3}return true;", + + f"{indent_lvl2}if (!(obj instanceof {class_name}))", + f"{indent_lvl3}return false;", + + f"{indent_lvl2}{class_name} that = ({class_name}) obj;" + ] + + for i, field in enumerate(fields): + _validate_java_identifier(field.name) + getter_name = _get_getter_name(field) + semicolon = ";" if i == (len(fields) - 1) else "" + if i == 0: + equals.append(f"{indent_lvl2}return Objects.equals({getter_name}(), that.{getter_name}()){semicolon}") + else: + equals.append(f"{indent_lvl2}{return_indent}&& Objects.equals({getter_name}(), that.{getter_name}()){semicolon}") + equals.append(f"{indent_lvl1}}}") + + return "\n".join(equals) + + def generate_hash_code(fields: List[Field]) -> str: hash_code = [ "", f"{indent_lvl1}@Override", - f"{indent_lvl1}public int hashCode() {{", - f"{indent_lvl2}return Objects.hash(" + f"{indent_lvl1}public int hashCode() {{" ] for i, field in enumerate(fields): _validate_java_identifier(field.name) - getter_name = "get" + field.name[0].upper() + field.name[1:] - comma = "," if i < (len(fields) - 1) else "" - hash_code.append(f"{indent_lvl2} {getter_name}(){comma}") + hash_code.extend(_get_return_hash_lines(fields, index=i)) - hash_code.append(f"{indent_lvl2});") hash_code.append(f"{indent_lvl1}}}") return "\n".join(hash_code) +def _get_return_hash_lines(fields, index) -> List[str]: + field = fields[index] + getter_name = _get_getter_name(field) + comma = "," if index < (len(fields) - 1) else "" + another_hash_line = f"{indent_lvl2}{return_indent}{getter_name}(){comma}" + if len(fields) == 1: + return [f"{indent_lvl2}return Objects.hash({getter_name}());"] + if len(fields) > 1 and index == 0: + return [f"{indent_lvl2}return Objects.hash(", + f"{another_hash_line}"] + if len(fields) > 1 and index == (len(fields) - 1): + return [f"{another_hash_line}", + f"{indent_lvl2});"] + return [another_hash_line] + + +def _get_getter_name(field): + getter_name = "get" + field.name[0].upper() + field.name[1:] + return getter_name + + +def _get_setter_name(field_name): + setter_name = "set" + field_name[0].upper() + field_name[1:] + return setter_name + + def _validate_java_identifier(name: str) -> None: if not name: raise ValueError("Field name cannot be empty") diff --git a/tests/test_header_generator.py b/tests/test_header_generator.py new file mode 100644 index 0000000..5347846 --- /dev/null +++ b/tests/test_header_generator.py @@ -0,0 +1,23 @@ +import pytest + +from src.header_generator import set_package + + +def test_set_package(): + example = "org.example.hyphenated_name" + expected = f"package {example}" + assert set_package(example) == expected + + example_2 = "com.example._123name" + expected_2 = f"package {example_2}" + assert set_package(example_2) == expected_2 + + +def test_set_package_illegal_name(): + example = "org.example.hyphenated-name" + with pytest.raises(ValueError): + set_package(example) + + example_2 = "com.example.123name" + with pytest.raises(ValueError): + set_package(example_2) diff --git a/tests/test_method_generator.py b/tests/test_method_generator.py index 3fa3f58..bcac62b 100644 --- a/tests/test_method_generator.py +++ b/tests/test_method_generator.py @@ -155,8 +155,62 @@ def test_generate_hash_code(): assert generate_hash_code(attributes) == expected +def test_generate_hash_code_one_field(): + attr = field_exampleAttribute_int + expected = """ + @Override + public int hashCode() { + return Objects.hash(getExampleAttribute()); + }""" + assert generate_hash_code([attr]) == expected + + def test_generate_hash_code_invalid_name(): for name in illegal_names: attr = Field(name=name, type="String") with pytest.raises(ValueError): generate_hash_code([attr]) + + +def test_generate_equals(): + class_name = "MyClass" + attr1 = field_exampleAttribute_int + attr2 = field_someName_String + attr3 = field_customData_CustomObject + attributes = [attr1, attr2, attr3] + expected = """ + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (!(obj instanceof MyClass)) + return false; + MyClass that = (MyClass) obj; + return Objects.equals(getExampleAttribute(), that.getExampleAttribute()) + && Objects.equals(getSomeName(), that.getSomeName()) + && Objects.equals(getCustomData(), that.getCustomData()); + }""" + assert generate_equals(class_name, attributes) == expected + + +def test_generate_equals_one_field(): + class_name = "AnotherClass" + attr = field_someName_String + expected = """ + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (!(obj instanceof AnotherClass)) + return false; + AnotherClass that = (AnotherClass) obj; + return Objects.equals(getSomeName(), that.getSomeName()); + }""" + assert generate_equals(class_name, [attr]) == expected + + +def test_generate_equals_invalid_name(): + for name in illegal_names: + attr = Field(name=name, type="String") + with pytest.raises(ValueError): + generate_equals("MyClass", [attr])