diff --git a/scapy/supersocket.py b/scapy/supersocket.py index 5ffd7a30f63..9c5b74a3b52 100644 --- a/scapy/supersocket.py +++ b/scapy/supersocket.py @@ -44,14 +44,21 @@ Optional, Tuple, Type, + TypeVar, cast, + TYPE_CHECKING, ) +from scapy.compat import Self + +if TYPE_CHECKING: + from scapy.ansmachine import AnsweringMachine + # Utils class _SuperSocket_metaclass(type): - desc = None # type: Optional[str] + desc = None # type: Optional[str] def __repr__(self): # type: () -> str @@ -82,10 +89,13 @@ class tpacket_auxdata(ctypes.Structure): # SuperSocket +_T = TypeVar("_T", Packet, PacketList) + + class SuperSocket(metaclass=_SuperSocket_metaclass): closed = False # type: bool nonblocking_socket = False # type: bool - auxdata_available = False # type: bool + auxdata_available = False # type: bool def __init__(self, family=socket.AF_INET, # type: int @@ -271,19 +281,17 @@ def tshark(self, *args, **kargs): from scapy import sendrecv sendrecv.tshark(opened_socket=self, *args, **kargs) - # TODO: use 'scapy.ansmachine.AnsweringMachine' when typed def am(self, - cls, # type: Type[Any] - *args, # type: Any + cls, # type: Type[AnsweringMachine[_T]] **kwargs # type: Any ): - # type: (...) -> Any + # type: (...) -> AnsweringMachine[_T] """ Creates an AnsweringMachine associated with this socket. :param cls: A subclass of AnsweringMachine to instantiate """ - return cls(*args, opened_socket=self, socket=self, **kwargs) + return cls(opened_socket=self, socket=self, **kwargs) @staticmethod def select(sockets, remain=conf.recv_poll_rate): @@ -295,6 +303,7 @@ def select(sockets, remain=conf.recv_poll_rate): :returns: an array of sockets that were selected and the function to be called next to get the packets (i.g. recv) """ + inp = [] # type: List[SuperSocket] try: inp, _, _ = select(sockets, [], [], remain) except (IOError, select_error) as exc: @@ -309,7 +318,7 @@ def __del__(self): self.close() def __enter__(self): - # type: () -> SuperSocket + # type: () -> Self return self def __exit__(self, exc_type, exc_value, traceback): @@ -627,6 +636,7 @@ def _iter(obj=cast(SndRcvList, obj)): s.time = s.sent_time yield s yield r + self.iter = _iter() elif isinstance(obj, (list, PacketList)): if isinstance(obj[0], bytes):