diff --git a/modules/cassandra/testcontainers/cassandra/__init__.py b/modules/cassandra/testcontainers/cassandra/__init__.py index f515aff1..2770451f 100644 --- a/modules/cassandra/testcontainers/cassandra/__init__.py +++ b/modules/cassandra/testcontainers/cassandra/__init__.py @@ -39,7 +39,9 @@ class CassandraContainer(DockerContainer): CQL_PORT = 9042 DEFAULT_LOCAL_DATACENTER = "datacenter1" - def __init__(self, image: str = "cassandra:latest", **kwargs) -> None: + def __init__( + self, image: str = "cassandra:latest", wait_strategy_check_string: str = "Startup complete", **kwargs + ) -> None: super().__init__(image=image, **kwargs) self.with_exposed_ports(self.CQL_PORT) self.with_env("JVM_OPTS", "-Dcassandra.skip_wait_for_gossip_to_settle=0 -Dcassandra.initial_token=0") @@ -47,7 +49,7 @@ def __init__(self, image: str = "cassandra:latest", **kwargs) -> None: self.with_env("MAX_HEAP_SIZE", "1024M") self.with_env("CASSANDRA_ENDPOINT_SNITCH", "GossipingPropertyFileSnitch") self.with_env("CASSANDRA_DC", self.DEFAULT_LOCAL_DATACENTER) - self.waiting_for(LogMessageWaitStrategy("Startup complete")) + self.waiting_for(LogMessageWaitStrategy(wait_strategy_check_string)) def get_contact_points(self) -> list[tuple[str, int]]: return [(self.get_container_host_ip(), int(self.get_exposed_port(self.CQL_PORT)))] diff --git a/modules/kafka/testcontainers/kafka/__init__.py b/modules/kafka/testcontainers/kafka/__init__.py index 55683f25..8eb4f718 100644 --- a/modules/kafka/testcontainers/kafka/__init__.py +++ b/modules/kafka/testcontainers/kafka/__init__.py @@ -55,12 +55,18 @@ class KafkaContainer(DockerContainer): TC_START_SCRIPT = "/tc-start.sh" MIN_KRAFT_TAG = "7.0.0" - def __init__(self, image: str = "confluentinc/cp-kafka:7.6.0", port: int = 9093, **kwargs) -> None: + def __init__( + self, + image: str = "confluentinc/cp-kafka:7.6.0", + port: int = 9093, + wait_strategy_check_string: str = r".*\[KafkaServer id=\d+\] started.*", + **kwargs, + ) -> None: raise_for_deprecated_parameter(kwargs, "port_to_expose", "port") super().__init__(image, **kwargs) self.port = port self.kraft_enabled = False - self.wait_for: re.Pattern[str] = re.compile(r".*\[KafkaServer id=\d+\] started.*") + self.wait_for: re.Pattern[str] = re.compile(wait_strategy_check_string) self.boot_command = "" self.cluster_id = "MkU3OEVBNTcwNTJENDM2Qk" self.listeners = f"PLAINTEXT://0.0.0.0:{self.port},BROKER://0.0.0.0:9092" diff --git a/modules/mysql/testcontainers/mysql/__init__.py b/modules/mysql/testcontainers/mysql/__init__.py index 51026d7f..0dba8d59 100644 --- a/modules/mysql/testcontainers/mysql/__init__.py +++ b/modules/mysql/testcontainers/mysql/__init__.py @@ -71,6 +71,7 @@ def __init__( dbname: Optional[str] = None, port: int = 3306, seed: Optional[str] = None, + wait_strategy_check_string: str = r".*: ready for connections.*: ready for connections.*", **kwargs, ) -> None: if dialect is not None and dialect.startswith("mysql+"): @@ -96,6 +97,7 @@ def __init__( if self.username == "root": self.root_password = self.password self.seed = seed + self.wait_strategy_check_string = wait_strategy_check_string def _configure(self) -> None: self.with_env("MYSQL_ROOT_PASSWORD", self.root_password) @@ -107,7 +109,7 @@ def _configure(self) -> None: def _connect(self) -> None: wait_strategy = LogMessageWaitStrategy( - re.compile(r".*: ready for connections.*: ready for connections.*", flags=re.DOTALL | re.MULTILINE), + re.compile(self.wait_strategy_check_string, flags=re.DOTALL | re.MULTILINE), ) wait_strategy.wait_until_ready(self) diff --git a/modules/trino/testcontainers/trino/__init__.py b/modules/trino/testcontainers/trino/__init__.py index ce180160..4532260b 100644 --- a/modules/trino/testcontainers/trino/__init__.py +++ b/modules/trino/testcontainers/trino/__init__.py @@ -24,6 +24,7 @@ def __init__( user: str = "test", port: int = 8080, container_start_timeout: int = 30, + wait_strategy_check_string: str = ".*======== SERVER STARTED ========.*", **kwargs, ): super().__init__(image=image, **kwargs) @@ -31,7 +32,7 @@ def __init__( self.port = port self.with_exposed_ports(self.port) self.waiting_for( - LogMessageWaitStrategy(re.compile(".*======== SERVER STARTED ========.*", re.MULTILINE)) + LogMessageWaitStrategy(re.compile(wait_strategy_check_string, re.MULTILINE)) .with_poll_interval(c.sleep_time) .with_startup_timeout(container_start_timeout) )