Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
name: Python tests

on:
push:
branches: ["main"]
pull_request:
branches: ["main"]

Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2025 MrMaydo
Copyright (c) 2025 Zbigniew Brzezicki

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
8 changes: 8 additions & 0 deletions src/header_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import re


def set_package(package: str) -> str:
PACKAGE_REGEX = r"^(?:[a-z_][a-z0-9_]*)(?:\.(?:[a-z_][a-z0-9_]*))*$"
if not re.match(PACKAGE_REGEX, package):
raise ValueError(f"Invalid package: '{package}'")
return f"package {package}"
67 changes: 59 additions & 8 deletions src/method_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
indent_lvl1 = " "
indent_lvl2 = indent_lvl1 * 2
indent_lvl3 = indent_lvl1 * 3
return_indent = " "


@dataclass
Expand Down Expand Up @@ -75,7 +76,7 @@ def generate_getter(field: Field) -> str:

_validate_java_identifier(field_name)

getter_name = "get" + field_name[0].upper() + field_name[1:]
getter_name = _get_getter_name(field)

getter = [
"",
Expand All @@ -94,7 +95,7 @@ def generate_setter(field: Field) -> str:

_validate_java_identifier(field_name)

setter_name = "set" + field_name[0].upper() + field_name[1:]
setter_name = _get_setter_name(field_name)

setter = [
"",
Expand All @@ -106,26 +107,76 @@ def generate_setter(field: Field) -> str:
return setter


def generate_equals(class_name: str, fields: List[Field]) -> str:
equals = [
"",
f"{indent_lvl1}@Override",
f"{indent_lvl1}public boolean equals(Object obj) {{",

f"{indent_lvl2}if (this == obj)",
f"{indent_lvl3}return true;",

f"{indent_lvl2}if (!(obj instanceof {class_name}))",
f"{indent_lvl3}return false;",

f"{indent_lvl2}{class_name} that = ({class_name}) obj;"
]

for i, field in enumerate(fields):
_validate_java_identifier(field.name)
getter_name = _get_getter_name(field)
semicolon = ";" if i == (len(fields) - 1) else ""
if i == 0:
equals.append(f"{indent_lvl2}return Objects.equals({getter_name}(), that.{getter_name}()){semicolon}")
else:
equals.append(f"{indent_lvl2}{return_indent}&& Objects.equals({getter_name}(), that.{getter_name}()){semicolon}")
equals.append(f"{indent_lvl1}}}")

return "\n".join(equals)


def generate_hash_code(fields: List[Field]) -> str:
hash_code = [
"",
f"{indent_lvl1}@Override",
f"{indent_lvl1}public int hashCode() {{",
f"{indent_lvl2}return Objects.hash("
f"{indent_lvl1}public int hashCode() {{"
]

for i, field in enumerate(fields):
_validate_java_identifier(field.name)
getter_name = "get" + field.name[0].upper() + field.name[1:]
comma = "," if i < (len(fields) - 1) else ""
hash_code.append(f"{indent_lvl2} {getter_name}(){comma}")
hash_code.extend(_get_return_hash_lines(fields, index=i))

hash_code.append(f"{indent_lvl2});")
hash_code.append(f"{indent_lvl1}}}")

return "\n".join(hash_code)


def _get_return_hash_lines(fields, index) -> List[str]:
field = fields[index]
getter_name = _get_getter_name(field)
comma = "," if index < (len(fields) - 1) else ""
another_hash_line = f"{indent_lvl2}{return_indent}{getter_name}(){comma}"
if len(fields) == 1:
return [f"{indent_lvl2}return Objects.hash({getter_name}());"]
if len(fields) > 1 and index == 0:
return [f"{indent_lvl2}return Objects.hash(",
f"{another_hash_line}"]
if len(fields) > 1 and index == (len(fields) - 1):
return [f"{another_hash_line}",
f"{indent_lvl2});"]
return [another_hash_line]


def _get_getter_name(field):
getter_name = "get" + field.name[0].upper() + field.name[1:]
return getter_name


def _get_setter_name(field_name):
setter_name = "set" + field_name[0].upper() + field_name[1:]
return setter_name


def _validate_java_identifier(name: str) -> None:
if not name:
raise ValueError("Field name cannot be empty")
Expand Down
23 changes: 23 additions & 0 deletions tests/test_header_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from src.header_generator import set_package


def test_set_package():
example = "org.example.hyphenated_name"
expected = f"package {example}"
assert set_package(example) == expected

example_2 = "com.example._123name"
expected_2 = f"package {example_2}"
assert set_package(example_2) == expected_2


def test_set_package_illegal_name():
example = "org.example.hyphenated-name"
with pytest.raises(ValueError):
set_package(example)

example_2 = "com.example.123name"
with pytest.raises(ValueError):
set_package(example_2)
54 changes: 54 additions & 0 deletions tests/test_method_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,62 @@ def test_generate_hash_code():
assert generate_hash_code(attributes) == expected


def test_generate_hash_code_one_field():
attr = field_exampleAttribute_int
expected = """
@Override
public int hashCode() {
return Objects.hash(getExampleAttribute());
}"""
assert generate_hash_code([attr]) == expected


def test_generate_hash_code_invalid_name():
for name in illegal_names:
attr = Field(name=name, type="String")
with pytest.raises(ValueError):
generate_hash_code([attr])


def test_generate_equals():
class_name = "MyClass"
attr1 = field_exampleAttribute_int
attr2 = field_someName_String
attr3 = field_customData_CustomObject
attributes = [attr1, attr2, attr3]
expected = """
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (!(obj instanceof MyClass))
return false;
MyClass that = (MyClass) obj;
return Objects.equals(getExampleAttribute(), that.getExampleAttribute())
&& Objects.equals(getSomeName(), that.getSomeName())
&& Objects.equals(getCustomData(), that.getCustomData());
}"""
assert generate_equals(class_name, attributes) == expected


def test_generate_equals_one_field():
class_name = "AnotherClass"
attr = field_someName_String
expected = """
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (!(obj instanceof AnotherClass))
return false;
AnotherClass that = (AnotherClass) obj;
return Objects.equals(getSomeName(), that.getSomeName());
}"""
assert generate_equals(class_name, [attr]) == expected


def test_generate_equals_invalid_name():
for name in illegal_names:
attr = Field(name=name, type="String")
with pytest.raises(ValueError):
generate_equals("MyClass", [attr])