Params4bit.__getattr__ breaks torch.compile - use @property instead #1904

@kbabiuchx

System Info

PR #1866 added __getattr__ to Params4bit for FSDP state_dict support. This works fine for FSDP, but it breaks torch.compile.

Params4bit is a torch.Tensor subclass. PyTorch's Dynamo (the compiler frontend) doesn't know how to trace tensor subclasses that define __getattr__, so it creates graph breaks whenever it encounters attribute access on such objects. With activation checkpointing, these graph breaks multiply across layers, resulting in many more subgraphs than necessary and significant compilation overhead.
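
To make the lookup difference concrete, here is a minimal pure-Python sketch. It does not use torch or bitsandbytes: `FakeQuantState`, `WithGetattr`, `WithProperty`, and `_ATTR_MAP` are hypothetical stand-ins. It only illustrates that `__getattr__` is a fallback hook that runs after normal lookup fails, while a property is an ordinary entry on the class. It does not reproduce the Dynamo behavior itself.

```python
class FakeQuantState:
    """Hypothetical stand-in for a quant state holding one field."""

    def __init__(self, absmax):
        self.absmax = absmax


class WithGetattr:
    """Forwards unknown attributes to quant_state via __getattr__."""

    # Hypothetical map, in the spirit of _QUANT_STATE_ATTR_MAP.
    _ATTR_MAP = {"absmax": lambda qs: qs.absmax}

    def __init__(self, quant_state=None):
        self.quant_state = quant_state

    def __getattr__(self, name):
        # Only reached when normal lookup (instance dict, class, bases) fails.
        getter = type(self)._ATTR_MAP.get(name)
        qs = self.__dict__.get("quant_state")
        if getter is not None and qs is not None:
            return getter(qs)
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )


class WithProperty:
    """Exposes the same field through a class-level property."""

    def __init__(self, quant_state=None):
        self.quant_state = quant_state

    @property
    def absmax(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.absmax
        raise AttributeError("'WithProperty' object has no attribute 'absmax'")


a = WithGetattr(FakeQuantState([1.0, 2.0]))
b = WithProperty(FakeQuantState([1.0, 2.0]))

# Both styles answer getattr(obj, "absmax") with the same value.
same_value = getattr(a, "absmax") == getattr(b, "absmax")

# The property is an entry in the class namespace; the fallback is not.
on_class_getattr = "absmax" in vars(WithGetattr)
on_class_property = isinstance(vars(WithProperty).get("absmax"), property)
```

In both variants `hasattr(obj, "absmax")` is False when `quant_state` is None, because the accessor raises AttributeError.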

We noticed this when running torch.compile on a QLoRA fine-tuning workload (LLaMA 70B, HuggingFace, activation checkpointing). With __getattr__ present, we saw significant performance degradation caused by graph breaks. Removing __getattr__ (or replacing with @property as proposed below) restores expected performance.

Reproduction

I put together a minimal repro using torch._dynamo.explain() on a small model with Linear4bit and activation checkpointing. With __getattr__ present it showed graph breaks; after removing it, they were gone. The script is rough, so I'd appreciate it if you could verify this with a more representative test — or suggest one if you have something better suited.

Expected behavior

Replace __getattr__ + _QUANT_STATE_ATTR_MAP with @property descriptors. Properties live on the class and are resolved through Python's descriptor protocol, which Dynamo handled without graph breaks in my repro. FSDP still works because getattr(weight, "absmax") returns the same value either way. Example:

    @property
    def absmax(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.absmax
        raise AttributeError("'Params4bit' object has no attribute 'absmax'")

    @property
    def code(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.code
        raise AttributeError("'Params4bit' object has no attribute 'code'")

    @property
    def quant_map(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.code
        raise AttributeError("'Params4bit' object has no attribute 'quant_map'")

    @property
    def offset(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.offset
        raise AttributeError("'Params4bit' object has no attribute 'offset'")

    @property
    def state2(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.state2
        raise AttributeError("'Params4bit' object has no attribute 'state2'")

    @property
    def nested_absmax(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None and qs.state2 is not None:
            return qs.state2.absmax
        raise AttributeError("'Params4bit' object has no attribute 'nested_absmax'")

    @property
    def nested_blocksize(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None and qs.state2 is not None:
            return qs.state2.blocksize
        raise AttributeError("'Params4bit' object has no attribute 'nested_blocksize'")

    @property
    def nested_quant_map(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None and qs.state2 is not None:
            return qs.state2.code
        raise AttributeError("'Params4bit' object has no attribute 'nested_quant_map'")

    @property
    def nested_dtype(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None and qs.state2 is not None:
            return qs.state2.dtype
        raise AttributeError("'Params4bit' object has no attribute 'nested_dtype'")

    @property
    def nested_offset(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None:
            return qs.offset
        raise AttributeError("'Params4bit' object has no attribute 'nested_offset'")
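
Since the properties above differ only in name and accessor, they could also be generated by a small factory. The following is a sketch under assumptions, not the bitsandbytes implementation: `_quant_state_property`, `FakeParams`, and the `SimpleNamespace` quant states are hypothetical, and a plain Python object stands in for `Params4bit`. Whether Dynamo traces factory-built properties the same way as the hand-written ones is an assumption that would need checking.

```python
from types import SimpleNamespace


def _quant_state_property(name, getter, needs_state2=False):
    """Build a read-only property that forwards `name` to self.quant_state.

    Hypothetical helper: raises AttributeError when quant_state (or, for
    nested fields, quant_state.state2) is missing, so hasattr() stays False.
    """

    def fget(self):
        qs = self.__dict__.get("quant_state")
        if qs is not None and (not needs_state2 or qs.state2 is not None):
            return getter(qs)
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    return property(fget, doc=f"Forwarded from quant_state ({name}).")


class FakeParams:
    """Hypothetical stand-in for Params4bit (a plain object, not a tensor)."""

    def __init__(self, quant_state=None):
        self.quant_state = quant_state

    # Same name-to-accessor pairs as in the hand-written properties.
    absmax = _quant_state_property("absmax", lambda qs: qs.absmax)
    quant_map = _quant_state_property("quant_map", lambda qs: qs.code)
    nested_absmax = _quant_state_property(
        "nested_absmax", lambda qs: qs.state2.absmax, needs_state2=True
    )


# Toy quant states: an outer state with a nested state2.
inner = SimpleNamespace(absmax=[0.5], code=[0, 1], state2=None)
outer = SimpleNamespace(absmax=[2.0], code=[-1, 1], state2=inner)
weight = FakeParams(outer)
```

With this in place, `getattr(weight, "nested_absmax")` returns the nested state's `absmax`, and `hasattr(FakeParams(), "absmax")` is False because the property raises AttributeError when no quant state is attached.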
