diff --git a/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx b/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx index 37d25869c..0c7cbc33f 100644 --- a/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx +++ b/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx @@ -1,20 +1,13 @@ import React, { useCallback, useEffect, useRef, useState } from "react"; -import { TLShapeId, createShapeId, useEditor, useValue } from "tldraw"; +import { TLShapeId, useEditor, useValue } from "tldraw"; import { DiscourseNodeShape } from "~/components/canvas/DiscourseNodeUtil"; import { - BaseDiscourseRelationUtil, - DiscourseRelationShape, - getRelationColor, -} from "~/components/canvas/DiscourseRelationShape/DiscourseRelationUtil"; -import { createOrUpdateArrowBinding } from "~/components/canvas/DiscourseRelationShape/helpers"; -import { - checkConnectionType, - getAllRelations, hasValidRelationTypes, isDiscourseNodeShape, } from "~/components/canvas/canvasUtils"; import { dispatchToastEvent } from "~/components/canvas/ToastListener"; import { RelationTypeDropdown } from "./RelationTypeDropdown"; +import { createRelationBetweenNodes } from "./relationCreation"; const HANDLE_RADIUS = 5; const HANDLE_HIT_AREA = 12; @@ -256,103 +249,13 @@ export const DragHandleOverlay = () => { (relationId: string) => { if (!pending) return; - const selectedRelation = getAllRelations().find( - (r) => r.id === relationId, - ); - if (!selectedRelation) { - setPending(null); - sourceNodeRef.current = null; - return; - } - - const color = getRelationColor(selectedRelation.label); - - // Determine direction: if we dragged from the relation's destination type, - // the arrow is in reverse and should display the complement label. - const sourceNode = editor.getShape(pending.sourceId); - const targetNode = editor.getShape(pending.targetId); - const { isReverse } = checkConnectionType( - selectedRelation, - sourceNode?.type ?? "", - targetNode?.type ?? "", - ); - const label = - isReverse && selectedRelation.complement - ? selectedRelation.complement - : selectedRelation.label; - - // Get source bounds for arrow positioning - const sourceBounds = editor.getShapePageBounds(pending.sourceId); - if (!sourceBounds) { - setPending(null); - sourceNodeRef.current = null; - return; - } - - // Create the real relation shape with the correct type - const arrowId = createShapeId(); - editor.createShape({ - id: arrowId, - type: relationId, - x: sourceBounds.midX, - y: sourceBounds.midY, - props: { - color, - text: label, - dash: "draw", - size: "m", - fill: "none", - bend: 0, - start: { x: 0, y: 0 }, - end: { x: 0, y: 0 }, - arrowheadStart: "none", - arrowheadEnd: "arrow", - labelPosition: 0.5, - font: "draw", - scale: 1, - }, - }); - - const newArrow = editor.getShape(arrowId); - if (!newArrow) { - setPending(null); - sourceNodeRef.current = null; - return; - } - - // Bind start and end - createOrUpdateArrowBinding(editor, newArrow, pending.sourceId, { - terminal: "start", - normalizedAnchor: { x: 0.5, y: 0.5 }, - isPrecise: false, - isExact: false, + void createRelationBetweenNodes({ + editor, + relationId, + sourceId: pending.sourceId, + targetId: pending.targetId, }); - createOrUpdateArrowBinding(editor, newArrow, pending.targetId, { - terminal: "end", - normalizedAnchor: { x: 0.5, y: 0.5 }, - isPrecise: false, - isExact: false, - }); - - // Persist via handleCreateRelationsInRoam - const util = editor.getShapeUtil(newArrow); - if ( - util instanceof BaseDiscourseRelationUtil && - "handleCreateRelationsInRoam" in util - ) { - type UtilWithRoamPersistence = BaseDiscourseRelationUtil & { - handleCreateRelationsInRoam: (args: { - arrow: DiscourseRelationShape; - targetId: TLShapeId; - }) => Promise; - }; - void (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ - arrow: editor.getShape(arrowId) ?? newArrow, - targetId: pending.targetId, - }); - } - editor.select(arrowId); setPending(null); sourceNodeRef.current = null; }, diff --git a/apps/roam/src/components/canvas/overlays/RelationTypeDropdown.tsx b/apps/roam/src/components/canvas/overlays/RelationTypeDropdown.tsx index 758a94033..f4997df55 100644 --- a/apps/roam/src/components/canvas/overlays/RelationTypeDropdown.tsx +++ b/apps/roam/src/components/canvas/overlays/RelationTypeDropdown.tsx @@ -1,11 +1,6 @@ import React, { useCallback, useEffect, useMemo, useRef } from "react"; -import { TLShapeId, useEditor, DefaultColorThemePalette } from "tldraw"; -import { getRelationColor } from "~/components/canvas/DiscourseRelationShape/DiscourseRelationUtil"; -import { - checkConnectionType, - getAllRelations, - isDiscourseNodeShape, -} from "~/components/canvas/canvasUtils"; +import { TLShapeId, useEditor } from "tldraw"; +import { getValidRelationTypesBetween } from "./relationCreation"; type RelationTypeDropdownProps = { sourceId: TLShapeId; @@ -25,49 +20,10 @@ export const RelationTypeDropdown = ({ const editor = useEditor(); const dropdownRef = useRef(null); - // Get valid relation types based on source/target node types - const validRelationTypes = useMemo(() => { - const startNode = editor.getShape(sourceId); - const endNode = editor.getShape(targetId); - if (!startNode || !endNode) return []; - - const startNodeType = startNode.type; - const endNodeType = endNode.type; - - // Verify both are discourse nodes - if ( - !isDiscourseNodeShape(editor, startNode) || - !isDiscourseNodeShape(editor, endNode) - ) - return []; - - const colorPalette = DefaultColorThemePalette.lightMode; - const validTypes: { id: string; label: string; color: string }[] = []; - const allRelations = getAllRelations(); - const seenLabels = new Set(); - - for (const relation of allRelations) { - const { isDirect: isForward, isReverse } = checkConnectionType( - relation, - startNodeType, - endNodeType, - ); - - if (!isForward && !isReverse) continue; - - const label = - isReverse && relation.complement ? relation.complement : relation.label; - - if (!seenLabels.has(label)) { - seenLabels.add(label); - const tldrawColor = getRelationColor(relation.label); - const hexColor = colorPalette[tldrawColor]?.solid ?? "#333"; - validTypes.push({ id: relation.id, label, color: hexColor }); - } - } - - return validTypes; - }, [editor, sourceId, targetId]); + const validRelationTypes = useMemo( + () => getValidRelationTypesBetween(editor, sourceId, targetId), + [editor, sourceId, targetId], + ); // Handle click outside useEffect(() => { diff --git a/apps/roam/src/components/canvas/overlays/relationCreation.ts b/apps/roam/src/components/canvas/overlays/relationCreation.ts new file mode 100644 index 000000000..12256f887 --- /dev/null +++ b/apps/roam/src/components/canvas/overlays/relationCreation.ts @@ -0,0 +1,150 @@ +import { + DefaultColorThemePalette, + Editor, + TLShapeId, + createShapeId, +} from "tldraw"; +import { + BaseDiscourseRelationUtil, + DiscourseRelationShape, + getRelationColor, +} from "~/components/canvas/DiscourseRelationShape/DiscourseRelationUtil"; +import { createOrUpdateArrowBinding } from "~/components/canvas/DiscourseRelationShape/helpers"; +import { + checkConnectionType, + getAllRelations, + isDiscourseNodeShape, +} from "~/components/canvas/canvasUtils"; + +type RelationTypeOption = { id: string; label: string; color: string }; + +export const getValidRelationTypesBetween = ( + editor: Editor, + startId: TLShapeId, + endId: TLShapeId, +): RelationTypeOption[] => { + const startNode = editor.getShape(startId); + const endNode = editor.getShape(endId); + if (!startNode || !endNode) return []; + if ( + !isDiscourseNodeShape(editor, startNode) || + !isDiscourseNodeShape(editor, endNode) + ) + return []; + + const colorPalette = DefaultColorThemePalette.lightMode; + const validTypes: RelationTypeOption[] = []; + const seenLabels = new Set(); + + for (const relation of getAllRelations()) { + const { isDirect, isReverse } = checkConnectionType( + relation, + startNode.type, + endNode.type, + ); + if (!isDirect && !isReverse) continue; + + const label = + isReverse && relation.complement ? relation.complement : relation.label; + if (seenLabels.has(label)) continue; + seenLabels.add(label); + + const hexColor = + colorPalette[getRelationColor(relation.label)]?.solid ?? "#333"; + validTypes.push({ id: relation.id, label, color: hexColor }); + } + + return validTypes; +}; + +export const createRelationBetweenNodes = async ({ + editor, + relationId, + sourceId, + targetId, +}: { + editor: Editor; + relationId: string; + sourceId: TLShapeId; + targetId: TLShapeId; +}): Promise => { + const selectedRelation = getAllRelations().find((r) => r.id === relationId); + if (!selectedRelation) return null; + + const sourceNode = editor.getShape(sourceId); + const targetNode = editor.getShape(targetId); + const { isReverse } = checkConnectionType( + selectedRelation, + sourceNode?.type ?? "", + targetNode?.type ?? "", + ); + const label = + isReverse && selectedRelation.complement + ? selectedRelation.complement + : selectedRelation.label; + + const sourceBounds = editor.getShapePageBounds(sourceId); + if (!sourceBounds) return null; + + const arrowId = createShapeId(); + editor.createShape({ + id: arrowId, + type: relationId, + x: sourceBounds.midX, + y: sourceBounds.midY, + props: { + color: getRelationColor(selectedRelation.label), + text: label, + dash: "draw", + size: "m", + fill: "none", + bend: 0, + start: { x: 0, y: 0 }, + end: { x: 0, y: 0 }, + arrowheadStart: "none", + arrowheadEnd: "arrow", + labelPosition: 0.5, + font: "draw", + scale: 1, + }, + }); + + const newArrow = editor.getShape(arrowId); + if (!newArrow) return null; + + createOrUpdateArrowBinding(editor, newArrow, sourceId, { + terminal: "start", + normalizedAnchor: { x: 0.5, y: 0.5 }, + isPrecise: false, + isExact: false, + }); + createOrUpdateArrowBinding(editor, newArrow, targetId, { + terminal: "end", + normalizedAnchor: { x: 0.5, y: 0.5 }, + isPrecise: false, + isExact: false, + }); + + editor.select(arrowId); + + const util = editor.getShapeUtil(newArrow); + if ( + util instanceof BaseDiscourseRelationUtil && + "handleCreateRelationsInRoam" in util + ) { + type UtilWithRoamPersistence = BaseDiscourseRelationUtil & { + handleCreateRelationsInRoam: (args: { + arrow: DiscourseRelationShape; + targetId: TLShapeId; + }) => Promise; + }; + await (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ + arrow: editor.getShape(arrowId) ?? newArrow, + targetId, + }); + } + + // handleCreateRelationsInRoam deletes the new arrow if it rejects the + // conversion, so a surviving shape means the relation was persisted. + return editor.getShape(arrowId) ? arrowId : null; +}; diff --git a/apps/roam/src/components/canvas/uiOverrides.tsx b/apps/roam/src/components/canvas/uiOverrides.tsx index 38025d810..d78e2fa0c 100644 --- a/apps/roam/src/components/canvas/uiOverrides.tsx +++ b/apps/roam/src/components/canvas/uiOverrides.tsx @@ -1,7 +1,9 @@ import React, { ReactElement } from "react"; import { + TLArrowBinding, TLImageShape, TLShape, + TLShapeId, TLTextShape, TLUiDialogProps, TLUiOverrides, @@ -49,6 +51,11 @@ import { COLOR_ARRAY } from "./DiscourseNodeUtil"; import calcCanvasNodeSizeAndImg from "~/utils/calcCanvasNodeSizeAndImg"; import { AddReferencedNodeType } from "./DiscourseRelationShape/DiscourseRelationTool"; import { getRelationColor } from "./DiscourseRelationShape/DiscourseRelationUtil"; +import { + createRelationBetweenNodes, + getValidRelationTypesBetween, +} from "./overlays/relationCreation"; +import { isDiscourseNodeShape } from "./canvasUtils"; import DiscourseGraphPanel from "./DiscourseToolPanel"; import type { CanvasNodeShortcuts } from "~/components/settings/utils/zodSchema"; import { CustomDefaultToolbar } from "./CustomDefaultToolbar"; @@ -224,6 +231,27 @@ export const getOnSelectForShape = ({ return () => {}; }; +const getArrowBoundNodeIds = ( + editor: Editor, + arrow: TLShape, +): { startId: TLShapeId; endId: TLShapeId } | null => { + const bindings = editor.getBindingsFromShape(arrow, "arrow"); + const startId = bindings.find((b) => b.props.terminal === "start")?.toId; + const endId = bindings.find((b) => b.props.terminal === "end")?.toId; + if (!startId || !endId) return null; + + const startShape = editor.getShape(startId); + const endShape = editor.getShape(endId); + if (!startShape || !endShape) return null; + if ( + !isDiscourseNodeShape(editor, startShape) || + !isDiscourseNodeShape(editor, endShape) + ) + return null; + + return { startId, endId }; +}; + export const CustomContextMenu = ({ extensionAPI, allNodes, @@ -239,6 +267,22 @@ export const CustomContextMenu = ({ ); const isTextSelected = selectedShape?.type === "text"; const isImageSelected = selectedShape?.type === "image"; + const arrowRelationOptions = useValue( + "arrowRelationOptions", + () => { + if (!selectedShape || selectedShape.type !== "arrow") return null; + const boundNodes = getArrowBoundNodeIds(editor, selectedShape); + if (!boundNodes) return null; + const relationTypes = getValidRelationTypesBetween( + editor, + boundNodes.startId, + boundNodes.endId, + ); + if (relationTypes.length === 0) return null; + return { arrowId: selectedShape.id, ...boundNodes, relationTypes }; + }, + [editor, selectedShape], + ); return ( @@ -268,6 +312,30 @@ export const CustomContextMenu = ({ )} + {arrowRelationOptions && ( + + + {arrowRelationOptions.relationTypes.map((rt) => ( + { + const newArrowId = await createRelationBetweenNodes({ + editor, + relationId: rt.id, + sourceId: arrowRelationOptions.startId, + targetId: arrowRelationOptions.endId, + }); + if (newArrowId) { + editor.deleteShapes([arrowRelationOptions.arrowId]); + } + }} + /> + ))} + + + )} ); };