Skip to content

Commit 8537a60

Browse files
authored
Merge pull request #52 from ForgeOpus/resnet_update
fixed custom block ui
2 parents a065c37 + d834f29 commit 8537a60

5 files changed

Lines changed: 45 additions & 33 deletions

File tree

project/block_manager/services/nodes/pytorch/batchnorm2d.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ class BatchNorm2DNode(NodeDefinition):
1010
@property
1111
def metadata(self) -> NodeMetadata:
1212
return NodeMetadata(
13-
type="batchnorm2d",
14-
label="BatchNorm2D",
13+
type="batchnorm",
14+
label="Batch Normalization",
1515
category="basic",
16-
color="var(--color-primary)",
16+
color="var(--color-accent)",
1717
icon="ChartLineUp",
18-
description="Batch normalization for 2D inputs",
18+
description="Batch normalization layer",
1919
framework=Framework.PYTORCH
2020
)
2121

project/block_manager/services/nodes/tensorflow/batchnorm2d.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ class BatchNorm2DNode(NodeDefinition):
1010
@property
1111
def metadata(self) -> NodeMetadata:
1212
return NodeMetadata(
13-
type="batchnorm2d",
14-
label="BatchNorm2D",
13+
type="batchnorm",
14+
label="Batch Normalization",
1515
category="basic",
16-
color="var(--color-orange)",
17-
icon="Zap",
16+
color="var(--color-accent)",
17+
icon="ChartLineUp",
1818
description="Batch normalization layer",
1919
framework=Framework.TENSORFLOW
2020
)

project/block_manager/services/validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def _validate_shape_compatibility(self):
324324
incoming = edge_map.get(node_id, [])
325325

326326
# Check that nodes with required inputs have connections
327-
if node_type in ('conv2d', 'linear', 'maxpool2d', 'maxpool', 'batchnorm', 'batchnorm2d', 'flatten'):
327+
if node_type in ('conv2d', 'linear', 'maxpool2d', 'maxpool', 'batchnorm', 'flatten'):
328328
if not incoming:
329329
self.errors.append(ValidationError(
330330
message=f'{node_type} layer requires an input connection',

project/frontend/src/components/GroupBlockNode.tsx

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,10 @@ const GroupBlockNode = memo(({ data, selected, id }: GroupBlockNodeProps) => {
154154

155155
{/* Render input handles */}
156156
{inputPorts.map((port, index) => {
157-
const spacing = 100 / (inputPorts.length + 1)
158-
const topPercent = spacing * (index + 1)
157+
const rangeStart = 70
158+
const rangeEnd = 90
159+
const spacing = (rangeEnd - rangeStart) / (inputPorts.length + 1)
160+
const topPercent = rangeStart + spacing * (index + 1)
159161
const color = getPortColor(port.semantic)
160162
const isConnected = isHandleConnected(port.externalPortId, true)
161163

@@ -212,45 +214,46 @@ const GroupBlockNode = memo(({ data, selected, id }: GroupBlockNodeProps) => {
212214
)
213215
})}
214216

215-
<div className="p-3 space-y-2">
217+
<div className="p-3">
216218
<div className="flex items-center gap-2">
217219
<div
218-
className="p-1.5 rounded"
220+
className="p-1 rounded shrink-0"
219221
style={{
220222
backgroundColor: groupDef.color,
221223
color: 'white'
222224
}}
223225
>
224-
<Icons.SquaresFour size={16} weight="bold" />
226+
<Icons.SquaresFour size={14} weight="bold" />
225227
</div>
226228
<div className="flex-1 min-w-0">
227-
<div className="font-semibold text-sm truncate">
229+
<div className="font-semibold text-sm truncate leading-tight">
228230
{groupDef.name}
229231
</div>
230-
<div className="flex items-center gap-1">
231-
<Badge
232-
variant="secondary"
233-
className="text-[9px] px-1 py-0 h-3.5"
234-
>
235-
{groupDef.category}
236-
</Badge>
237-
<Badge
238-
variant="outline"
239-
className="text-[9px] px-1 py-0 h-3.5"
240-
>
241-
{groupDef.internalNodes.length} nodes
242-
</Badge>
243-
</div>
244232
</div>
245233
</div>
246234

235+
<div className="flex items-center gap-1 mt-1">
236+
<Badge
237+
variant="secondary"
238+
className="text-[9px] px-1 py-0 h-3.5"
239+
>
240+
{groupDef.category}
241+
</Badge>
242+
<Badge
243+
variant="outline"
244+
className="text-[9px] px-1 py-0 h-3.5"
245+
>
246+
{groupDef.internalNodes.length} nodes
247+
</Badge>
248+
</div>
249+
247250
{groupDef.description && (
248-
<div className="text-[10px] text-muted-foreground line-clamp-2">
251+
<div className="text-[10px] text-muted-foreground line-clamp-2 mt-1">
249252
{groupDef.description}
250253
</div>
251254
)}
252255

253-
<div className="flex items-center gap-1 text-[10px] text-muted-foreground">
256+
<div className="flex items-center gap-1 text-[10px] text-muted-foreground mt-1">
254257
<Icons.ArrowsIn size={12} />
255258
<span>{inputPorts.length} in</span>
256259
<span className="mx-1"></span>
@@ -261,8 +264,10 @@ const GroupBlockNode = memo(({ data, selected, id }: GroupBlockNodeProps) => {
261264

262265
{/* Render output handles */}
263266
{outputPorts.map((port, index) => {
264-
const spacing = 100 / (outputPorts.length + 1)
265-
const topPercent = spacing * (index + 1)
267+
const rangeStart = 70
268+
const rangeEnd = 90
269+
const spacing = (rangeEnd - rangeStart) / (outputPorts.length + 1)
270+
const topPercent = rangeStart + spacing * (index + 1)
266271
const color = getPortColor(port.semantic)
267272
const isConnected = isHandleConnected(port.externalPortId, false)
268273

project/frontend/src/index.css

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ body {
120120
font-family: var(--font-sans);
121121
}
122122

123+
.react-flow__node-group {
124+
background: transparent !important;
125+
border: none !important;
126+
padding: 0 !important;
127+
box-shadow: none !important;
128+
}
129+
123130
code, pre {
124131
font-family: var(--font-mono);
125132
}

0 commit comments

Comments
 (0)