This repository was archived by the owner on Dec 26, 2025. It is now read-only.
forked from pschroedl/StreamDiffusion
-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy path image_processing_module.py
More file actions
153 lines (118 loc) · 6.63 KB
/
image_processing_module.py
File metadata and controls
153 lines (118 loc) · 6.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from typing import List, Optional, Any, Dict
import torch
from ..preprocessing.orchestrator_user import OrchestratorUser
from ..preprocessing.pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator
from ..hooks import ImageCtx, ImageHook
class ImageProcessingModule(OrchestratorUser):
    """
    Common base for image-domain processing modules.

    Provides sequential chain execution shared by both the preprocessing
    and postprocessing timing variants; the processing domain is always
    image tensors.
    """

    def __init__(self):
        """Start with an empty processor chain."""
        self.processors = []

    def _process_image_chain(self, input_image: torch.Tensor) -> torch.Tensor:
        """Run the enabled processors sequentially over ``input_image``.

        Delegates to the shared orchestrator's chain execution. Returns the
        input unchanged when no processors are registered.
        """
        if not self.processors:
            return input_image
        chain = self._get_ordered_processors()
        return self._preprocessing_orchestrator.execute_pipeline_chain(
            input_image, chain, processing_domain="image"
        )

    def add_processor(self, proc_config: Dict[str, Any]) -> None:
        """Instantiate and register a processor from a config dict.

        Follows the ControlNet pattern: the type is looked up in the
        processor registry and ``params`` are passed through as constructor
        kwargs.

        Raises:
            ValueError: if the config has no 'type' field.
        """
        from streamdiffusion.preprocessing.processors import get_preprocessor

        processor_type = proc_config.get('type')
        if not processor_type:
            raise ValueError("Processor config missing 'type' field")

        # Enabled defaults to True, mirroring ControlNet behavior.
        is_enabled = proc_config.get('enabled', True)
        ctor_kwargs = proc_config.get('params', {})
        processor = get_preprocessor(
            processor_type,
            pipeline_ref=getattr(self, '_stream', None),
            **ctor_kwargs,
        )

        # Execution order defaults to the append position in the chain.
        setattr(processor, 'order', proc_config.get('order', len(self.processors)))
        setattr(processor, 'enabled', is_enabled)

        # Best-effort: align the processor's target size with the stream
        # resolution (same as ControlNet). Failures are deliberately
        # swallowed so a mismatched processor never blocks registration.
        if hasattr(self, '_stream'):
            try:
                width = int(self._stream.width)
                height = int(self._stream.height)
                params = getattr(processor, 'params', None)
                if isinstance(params, dict):
                    params['image_width'] = width
                    params['image_height'] = height
                if hasattr(processor, 'image_width'):
                    setattr(processor, 'image_width', width)
                if hasattr(processor, 'image_height'):
                    setattr(processor, 'image_height', height)
            except Exception:
                pass

        self.processors.append(processor)

    def _get_ordered_processors(self) -> List[Any]:
        """Return enabled processors sorted by their 'order' attribute."""
        enabled = (p for p in self.processors if getattr(p, 'enabled', True))
        return sorted(enabled, key=lambda p: getattr(p, 'order', 0))
class ImagePreprocessingModule(ImageProcessingModule):
    """
    Image-domain preprocessing module; runs before VAE encoding.

    Timing: after image_processor.preprocess(), before similar_image_filter.
    Uses pipelined processing for performance.
    """

    def install(self, stream) -> None:
        """Register this module's hook on ``stream`` and attach orchestrators."""
        # Keep a stream reference so processors can read width/height.
        self._stream = stream
        # Sequential-chain orchestrator (fallback path).
        self.attach_orchestrator(stream)
        # Pipelined orchestrator (primary path).
        self.attach_pipeline_preprocessing_orchestrator(stream)
        stream.image_preprocessing_hooks.append(self.build_image_hook())

    def build_image_hook(self) -> ImageHook:
        """Create the hook that routes the image through pipelined processing."""
        def _hook(ctx: ImageCtx) -> ImageCtx:
            ctx.image = self._process_image_pipelined(ctx.image)
            return ctx
        return _hook

    def _process_image_pipelined(self, input_image: torch.Tensor) -> torch.Tensor:
        """Run the preprocessors via the pipelined orchestrator.

        Overlaps Frame N-1 results with Frame N work for throughput.
        Returns the input unchanged when no processors are registered.
        """
        if not self.processors:
            return input_image
        chain = self._get_ordered_processors()
        return self._pipeline_preprocessing_orchestrator.process_pipelined(
            input_image, chain
        )
class ImagePostprocessingModule(ImageProcessingModule):
    """
    Image-domain postprocessing module; runs after VAE decoding.

    Timing: after decode_image(), before the final output is returned.
    Uses pipelined processing for performance.
    """

    def install(self, stream) -> None:
        """Register this module's hook on ``stream`` and attach orchestrators."""
        # Keep a stream reference so processors can read width/height.
        self._stream = stream
        # Sequential-chain orchestrator (fallback path).
        # NOTE(review): the preprocessing sibling calls attach_orchestrator()
        # here instead — confirm which attach method the base class expects.
        self.attach_preprocessing_orchestrator(stream)
        # Pipelined orchestrator (primary path).
        self.attach_postprocessing_orchestrator(stream)
        stream.image_postprocessing_hooks.append(self.build_image_hook())

    def build_image_hook(self) -> ImageHook:
        """Create the hook that routes the image through pipelined processing."""
        def _hook(ctx: ImageCtx) -> ImageCtx:
            ctx.image = self._process_image_pipelined(ctx.image)
            return ctx
        return _hook

    def _process_image_pipelined(self, input_image: torch.Tensor) -> torch.Tensor:
        """Run the postprocessors via the pipelined orchestrator.

        Overlaps Frame N-1 results with Frame N work for throughput.
        Returns the input unchanged when no processors are registered.
        """
        if not self.processors:
            return input_image
        chain = self._get_ordered_processors()
        return self._postprocessing_orchestrator.process_pipelined(
            input_image, chain
        )