diff --git a/project/block_manager/services/nodes/pytorch/batchnorm2d.py b/project/block_manager/services/nodes/pytorch/batchnorm2d.py index 0ab1f9f..4458030 100644 --- a/project/block_manager/services/nodes/pytorch/batchnorm2d.py +++ b/project/block_manager/services/nodes/pytorch/batchnorm2d.py @@ -10,12 +10,12 @@ class BatchNorm2DNode(NodeDefinition): @property def metadata(self) -> NodeMetadata: return NodeMetadata( - type="batchnorm2d", - label="BatchNorm2D", + type="batchnorm", + label="Batch Normalization", category="basic", - color="var(--color-primary)", + color="var(--color-accent)", icon="ChartLineUp", - description="Batch normalization for 2D inputs", + description="Batch normalization layer", framework=Framework.PYTORCH ) diff --git a/project/block_manager/services/nodes/tensorflow/batchnorm2d.py b/project/block_manager/services/nodes/tensorflow/batchnorm2d.py index 5738056..17bea29 100644 --- a/project/block_manager/services/nodes/tensorflow/batchnorm2d.py +++ b/project/block_manager/services/nodes/tensorflow/batchnorm2d.py @@ -10,11 +10,11 @@ class BatchNorm2DNode(NodeDefinition): @property def metadata(self) -> NodeMetadata: return NodeMetadata( - type="batchnorm2d", - label="BatchNorm2D", + type="batchnorm", + label="Batch Normalization", category="basic", - color="var(--color-orange)", - icon="Zap", + color="var(--color-accent)", + icon="ChartLineUp", description="Batch normalization layer", framework=Framework.TENSORFLOW ) diff --git a/project/block_manager/services/validation.py b/project/block_manager/services/validation.py index da5079b..6b955b4 100644 --- a/project/block_manager/services/validation.py +++ b/project/block_manager/services/validation.py @@ -324,7 +324,7 @@ def _validate_shape_compatibility(self): incoming = edge_map.get(node_id, []) # Check that nodes with required inputs have connections - if node_type in ('conv2d', 'linear', 'maxpool2d', 'maxpool', 'batchnorm', 'batchnorm2d', 'flatten'): + if node_type in ('conv2d', 'linear', 'maxpool2d', 'maxpool', 'batchnorm', 'flatten'): if not incoming: self.errors.append(ValidationError( message=f'{node_type} layer requires an input connection', diff --git a/project/frontend/src/components/GroupBlockNode.tsx b/project/frontend/src/components/GroupBlockNode.tsx index 2899d4d..b4af99b 100644 --- a/project/frontend/src/components/GroupBlockNode.tsx +++ b/project/frontend/src/components/GroupBlockNode.tsx @@ -154,8 +154,10 @@ const GroupBlockNode = memo(({ data, selected, id }: GroupBlockNodeProps) => { {/* Render input handles */} {inputPorts.map((port, index) => { - const spacing = 100 / (inputPorts.length + 1) - const topPercent = spacing * (index + 1) + const rangeStart = 70 + const rangeEnd = 90 + const spacing = (rangeEnd - rangeStart) / (inputPorts.length + 1) + const topPercent = rangeStart + spacing * (index + 1) const color = getPortColor(port.semantic) const isConnected = isHandleConnected(port.externalPortId, true) @@ -212,45 +214,46 @@ const GroupBlockNode = memo(({ data, selected, id }: GroupBlockNodeProps) => { ) })} -