diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py
index fa1f0f89..18a8256c 100644
--- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py
+++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py
@@ -53,7 +53,20 @@ def _get_base_unet(self, unet):
         if hasattr(attr, 'config') and hasattr(attr.config, 'addition_embed_type'):
             return attr
         return unet
-
+
+    def _resolve_num_ip_layers(self) -> Optional[int]:
+        """Resolve IP-Adapter layer count from known wrapper layouts."""
+        direct_layers = getattr(self.unet, "num_ip_layers", None)
+        if isinstance(direct_layers, int) and direct_layers > 0:
+            return direct_layers
+
+        ip_wrap = getattr(self.unet, "ipadapter_wrapper", None)
+        wrapped_layers = getattr(ip_wrap, "num_ip_layers", None)
+        if isinstance(wrapped_layers, int) and wrapped_layers > 0:
+            return wrapped_layers
+
+        return None
+
     def _test_added_cond_support(self):
         """Test if this SDXL model supports added_cond_kwargs"""
         try:
@@ -61,6 +74,23 @@ def _test_added_cond_support(self):
             sample = torch.randn(1, 4, 8, 8, device='cuda', dtype=torch.float16)
             timestep = torch.tensor([0.5], device='cuda', dtype=torch.float32)
             encoder_hidden_states = torch.randn(1, 77, 2048, device='cuda', dtype=torch.float16)
+            probe_args = [sample, timestep, encoder_hidden_states]
+
+            if getattr(self.unet, "use_ipadapter", False):
+                num_ip_layers = self._resolve_num_ip_layers()
+                if num_ip_layers is None:
+                    direct_layers = getattr(self.unet, "num_ip_layers", None)
+                    wrapped_layers = getattr(getattr(self.unet, "ipadapter_wrapper", None), "num_ip_layers", None)
+                    logger.error(
+                        "SDXL probe: use_ipadapter=True but num_ip_layers is invalid: direct=%s wrapped=%s",
+                        direct_layers,
+                        wrapped_layers,
+                    )
+                    return False
+
+                probe_args.append(
+                    torch.ones(num_ip_layers, device=sample.device, dtype=torch.float32)
+                )
 
             # Test with added_cond_kwargs
             test_added_cond = {
@@ -69,7 +99,7 @@ def _test_added_cond_support(self):
             }
 
             with torch.no_grad():
-                _ = self.unet(sample, timestep, encoder_hidden_states, added_cond_kwargs=test_added_cond)
+                _ = self.unet(*probe_args, added_cond_kwargs=test_added_cond)
 
             logger.info("SDXL model supports added_cond_kwargs")
             return True