From 65d8f60052b284ad031710a4f060d700a52ce3ea Mon Sep 17 00:00:00 2001 From: kramiusmaximus Date: Sun, 22 Feb 2026 18:48:24 +0300 Subject: [PATCH 1/3] added ip_adapter scale dummy for model probe --- .../export_wrappers/unet_sdxl_export.py | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py index fa1f0f89..eb47aa81 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py @@ -53,6 +53,46 @@ def _get_base_unet(self, unet): if hasattr(attr, 'config') and hasattr(attr.config, 'addition_embed_type'): return attr return unet + + def _resolve_num_ip_layers(self) -> Optional[int]: + """Resolve IP-Adapter layer count from known wrapper layouts.""" + direct_layers = getattr(self.unet, "num_ip_layers", None) + if isinstance(direct_layers, int) and direct_layers > 0: + return direct_layers + + ip_wrap = getattr(self.unet, "ipadapter_wrapper", None) + wrapped_layers = getattr(ip_wrap, "num_ip_layers", None) + if isinstance(wrapped_layers, int) and wrapped_layers > 0: + return wrapped_layers + + return None + + def _build_probe_args( + self, + sample: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ) -> Optional[List[torch.Tensor]]: + """Build positional args for SDXL support probing.""" + probe_args: List[torch.Tensor] = [sample, timestep, encoder_hidden_states] + if not getattr(self.unet, "use_ipadapter", False): + return probe_args + + num_ip_layers = self._resolve_num_ip_layers() + if num_ip_layers is None: + direct_layers = getattr(self.unet, "num_ip_layers", None) + wrapped_layers = getattr(getattr(self.unet, "ipadapter_wrapper", None), "num_ip_layers", None) + logger.error( + "SDXL probe: use_ipadapter=True but num_ip_layers is 
invalid: direct=%s wrapped=%s", + direct_layers, + wrapped_layers, + ) + return None + + probe_args.append( + torch.ones(num_ip_layers, device=sample.device, dtype=torch.float32) + ) + return probe_args def _test_added_cond_support(self): """Test if this SDXL model supports added_cond_kwargs""" @@ -67,9 +107,17 @@ def _test_added_cond_support(self): 'text_embeds': torch.randn(1, 1280, device='cuda', dtype=torch.float16), 'time_ids': torch.randn(1, 6, device='cuda', dtype=torch.float16) } + + # UnifiedExportWrapper with IP-Adapter enabled requires a positional + # ipadapter_scale tensor before kwargs. During this capability probe + # we provide a sample vector only for call-shape validation; real + # per-step values are supplied at runtime by IP-Adapter hooks. + probe_args = self._build_probe_args(sample, timestep, encoder_hidden_states) + if probe_args is None: + return False with torch.no_grad(): - _ = self.unet(sample, timestep, encoder_hidden_states, added_cond_kwargs=test_added_cond) + _ = self.unet(*probe_args, added_cond_kwargs=test_added_cond) logger.info("SDXL model supports added_cond_kwargs") return True From c38d1b0fd8c541024ec30900c5676600b2bedc4f Mon Sep 17 00:00:00 2001 From: kramiusmaximus Date: Sun, 22 Feb 2026 20:08:51 +0300 Subject: [PATCH 2/3] refactor --- .../export_wrappers/unet_sdxl_export.py | 56 +++++++------------ 1 file changed, 19 insertions(+), 37 deletions(-) diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py index eb47aa81..4c480b69 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py @@ -5,7 +5,7 @@ """ import torch -from typing import Dict, List, Optional, Tuple, Any, Union +from typing import Dict, Optional, Tuple, Any, Union from diffusers import UNet2DConditionModel from ....model_detection 
import ( detect_model, @@ -67,33 +67,6 @@ def _resolve_num_ip_layers(self) -> Optional[int]: return None - def _build_probe_args( - self, - sample: torch.Tensor, - timestep: torch.Tensor, - encoder_hidden_states: torch.Tensor, - ) -> Optional[List[torch.Tensor]]: - """Build positional args for SDXL support probing.""" - probe_args: List[torch.Tensor] = [sample, timestep, encoder_hidden_states] - if not getattr(self.unet, "use_ipadapter", False): - return probe_args - - num_ip_layers = self._resolve_num_ip_layers() - if num_ip_layers is None: - direct_layers = getattr(self.unet, "num_ip_layers", None) - wrapped_layers = getattr(getattr(self.unet, "ipadapter_wrapper", None), "num_ip_layers", None) - logger.error( - "SDXL probe: use_ipadapter=True but num_ip_layers is invalid: direct=%s wrapped=%s", - direct_layers, - wrapped_layers, - ) - return None - - probe_args.append( - torch.ones(num_ip_layers, device=sample.device, dtype=torch.float32) - ) - return probe_args - def _test_added_cond_support(self): """Test if this SDXL model supports added_cond_kwargs""" try: @@ -101,20 +74,29 @@ def _test_added_cond_support(self): sample = torch.randn(1, 4, 8, 8, device='cuda', dtype=torch.float16) timestep = torch.tensor([0.5], device='cuda', dtype=torch.float32) encoder_hidden_states = torch.randn(1, 77, 2048, device='cuda', dtype=torch.float16) + probe_args = [sample, timestep, encoder_hidden_states] + + if getattr(self.unet, "use_ipadapter", False): + num_ip_layers = self._resolve_num_ip_layers() + if num_ip_layers is None: + direct_layers = getattr(self.unet, "num_ip_layers", None) + wrapped_layers = getattr(getattr(self.unet, "ipadapter_wrapper", None), "num_ip_layers", None) + logger.error( + "SDXL probe: use_ipadapter=True but num_ip_layers is invalid: direct=%s wrapped=%s", + direct_layers, + wrapped_layers, + ) + return False + + probe_args.append( + torch.ones(num_ip_layers, device=sample.device, dtype=torch.float32) + ) # Test with added_cond_kwargs test_added_cond 
= { 'text_embeds': torch.randn(1, 1280, device='cuda', dtype=torch.float16), 'time_ids': torch.randn(1, 6, device='cuda', dtype=torch.float16) } - - # UnifiedExportWrapper with IP-Adapter enabled requires a positional - # ipadapter_scale tensor before kwargs. During this capability probe - # we provide a sample vector only for call-shape validation; real - # per-step values are supplied at runtime by IP-Adapter hooks. - probe_args = self._build_probe_args(sample, timestep, encoder_hidden_states) - if probe_args is None: - return False with torch.no_grad(): _ = self.unet(*probe_args, added_cond_kwargs=test_added_cond) @@ -361,4 +343,4 @@ def get_sdxl_tensorrt_config(model_path: str, unet: UNet2DConditionModel) -> Dic conditioning_handler = SDXLConditioningHandler(config) config['conditioning_spec'] = conditioning_handler.get_conditioning_spec() - return config \ No newline at end of file + return config From f586769bc28e778adc7b7d2624f9d82bf8cf65d9 Mon Sep 17 00:00:00 2001 From: kramiusmaximus Date: Sun, 22 Feb 2026 20:25:28 +0300 Subject: [PATCH 3/3] removed unnecessary changes --- .../acceleration/tensorrt/export_wrappers/unet_sdxl_export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py index 4c480b69..18a8256c 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py @@ -5,7 +5,7 @@ """ import torch -from typing import Dict, Optional, Tuple, Any, Union +from typing import Dict, List, Optional, Tuple, Any, Union from diffusers import UNet2DConditionModel from ....model_detection import ( detect_model, @@ -343,4 +343,4 @@ def get_sdxl_tensorrt_config(model_path: str, unet: UNet2DConditionModel) -> Dic conditioning_handler = SDXLConditioningHandler(config)
config['conditioning_spec'] = conditioning_handler.get_conditioning_spec() - return config + return config \ No newline at end of file