Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 7 additions & 104 deletions apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import React, { useCallback, useEffect, useRef, useState } from "react";
import { TLShapeId, createShapeId, useEditor, useValue } from "tldraw";
import { TLShapeId, useEditor, useValue } from "tldraw";
import { DiscourseNodeShape } from "~/components/canvas/DiscourseNodeUtil";
import {
BaseDiscourseRelationUtil,
DiscourseRelationShape,
getRelationColor,
} from "~/components/canvas/DiscourseRelationShape/DiscourseRelationUtil";
import { createOrUpdateArrowBinding } from "~/components/canvas/DiscourseRelationShape/helpers";
import {
checkConnectionType,
getAllRelations,
hasValidRelationTypes,
isDiscourseNodeShape,
} from "~/components/canvas/canvasUtils";
import { dispatchToastEvent } from "~/components/canvas/ToastListener";
import { RelationTypeDropdown } from "./RelationTypeDropdown";
import { createRelationBetweenNodes } from "./relationCreation";

const HANDLE_RADIUS = 5;
const HANDLE_HIT_AREA = 12;
Expand Down Expand Up @@ -256,103 +249,13 @@ export const DragHandleOverlay = () => {
(relationId: string) => {
if (!pending) return;

const selectedRelation = getAllRelations().find(
(r) => r.id === relationId,
);
if (!selectedRelation) {
setPending(null);
sourceNodeRef.current = null;
return;
}

const color = getRelationColor(selectedRelation.label);

// Determine direction: if we dragged from the relation's destination type,
// the arrow is in reverse and should display the complement label.
const sourceNode = editor.getShape(pending.sourceId);
const targetNode = editor.getShape(pending.targetId);
const { isReverse } = checkConnectionType(
selectedRelation,
sourceNode?.type ?? "",
targetNode?.type ?? "",
);
const label =
isReverse && selectedRelation.complement
? selectedRelation.complement
: selectedRelation.label;

// Get source bounds for arrow positioning
const sourceBounds = editor.getShapePageBounds(pending.sourceId);
if (!sourceBounds) {
setPending(null);
sourceNodeRef.current = null;
return;
}

// Create the real relation shape with the correct type
const arrowId = createShapeId();
editor.createShape<DiscourseRelationShape>({
id: arrowId,
type: relationId,
x: sourceBounds.midX,
y: sourceBounds.midY,
props: {
color,
text: label,
dash: "draw",
size: "m",
fill: "none",
bend: 0,
start: { x: 0, y: 0 },
end: { x: 0, y: 0 },
arrowheadStart: "none",
arrowheadEnd: "arrow",
labelPosition: 0.5,
font: "draw",
scale: 1,
},
});

const newArrow = editor.getShape<DiscourseRelationShape>(arrowId);
if (!newArrow) {
setPending(null);
sourceNodeRef.current = null;
return;
}

// Bind start and end
createOrUpdateArrowBinding(editor, newArrow, pending.sourceId, {
terminal: "start",
normalizedAnchor: { x: 0.5, y: 0.5 },
isPrecise: false,
isExact: false,
void createRelationBetweenNodes({
editor,
relationId,
sourceId: pending.sourceId,
targetId: pending.targetId,
});
createOrUpdateArrowBinding(editor, newArrow, pending.targetId, {
terminal: "end",
normalizedAnchor: { x: 0.5, y: 0.5 },
isPrecise: false,
isExact: false,
});

// Persist via handleCreateRelationsInRoam
const util = editor.getShapeUtil(newArrow);
if (
util instanceof BaseDiscourseRelationUtil &&
"handleCreateRelationsInRoam" in util
) {
type UtilWithRoamPersistence = BaseDiscourseRelationUtil & {
handleCreateRelationsInRoam: (args: {
arrow: DiscourseRelationShape;
targetId: TLShapeId;
}) => Promise<void>;
};
void (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({
arrow: editor.getShape<DiscourseRelationShape>(arrowId) ?? newArrow,
targetId: pending.targetId,
});
}

editor.select(arrowId);
setPending(null);
sourceNodeRef.current = null;
},
Expand Down
56 changes: 6 additions & 50 deletions apps/roam/src/components/canvas/overlays/RelationTypeDropdown.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import React, { useCallback, useEffect, useMemo, useRef } from "react";
import { TLShapeId, useEditor, DefaultColorThemePalette } from "tldraw";
import { getRelationColor } from "~/components/canvas/DiscourseRelationShape/DiscourseRelationUtil";
import {
checkConnectionType,
getAllRelations,
isDiscourseNodeShape,
} from "~/components/canvas/canvasUtils";
import { TLShapeId, useEditor } from "tldraw";
import { getValidRelationTypesBetween } from "./relationCreation";

type RelationTypeDropdownProps = {
sourceId: TLShapeId;
Expand All @@ -25,49 +20,10 @@ export const RelationTypeDropdown = ({
const editor = useEditor();
const dropdownRef = useRef<HTMLDivElement>(null);

// Get valid relation types based on source/target node types
const validRelationTypes = useMemo(() => {
const startNode = editor.getShape(sourceId);
const endNode = editor.getShape(targetId);
if (!startNode || !endNode) return [];

const startNodeType = startNode.type;
const endNodeType = endNode.type;

// Verify both are discourse nodes
if (
!isDiscourseNodeShape(editor, startNode) ||
!isDiscourseNodeShape(editor, endNode)
)
return [];

const colorPalette = DefaultColorThemePalette.lightMode;
const validTypes: { id: string; label: string; color: string }[] = [];
const allRelations = getAllRelations();
const seenLabels = new Set<string>();

for (const relation of allRelations) {
const { isDirect: isForward, isReverse } = checkConnectionType(
relation,
startNodeType,
endNodeType,
);

if (!isForward && !isReverse) continue;

const label =
isReverse && relation.complement ? relation.complement : relation.label;

if (!seenLabels.has(label)) {
seenLabels.add(label);
const tldrawColor = getRelationColor(relation.label);
const hexColor = colorPalette[tldrawColor]?.solid ?? "#333";
validTypes.push({ id: relation.id, label, color: hexColor });
}
}

return validTypes;
}, [editor, sourceId, targetId]);
const validRelationTypes = useMemo(
() => getValidRelationTypesBetween(editor, sourceId, targetId),
[editor, sourceId, targetId],
);

// Handle click outside
useEffect(() => {
Expand Down
150 changes: 150 additions & 0 deletions apps/roam/src/components/canvas/overlays/relationCreation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import {
DefaultColorThemePalette,
Editor,
TLShapeId,
createShapeId,
} from "tldraw";
import {
BaseDiscourseRelationUtil,
DiscourseRelationShape,
getRelationColor,
} from "~/components/canvas/DiscourseRelationShape/DiscourseRelationUtil";
import { createOrUpdateArrowBinding } from "~/components/canvas/DiscourseRelationShape/helpers";
import {
checkConnectionType,
getAllRelations,
isDiscourseNodeShape,
} from "~/components/canvas/canvasUtils";

type RelationTypeOption = { id: string; label: string; color: string };

export const getValidRelationTypesBetween = (
editor: Editor,
startId: TLShapeId,
endId: TLShapeId,
): RelationTypeOption[] => {
const startNode = editor.getShape(startId);
const endNode = editor.getShape(endId);
if (!startNode || !endNode) return [];
if (
!isDiscourseNodeShape(editor, startNode) ||
!isDiscourseNodeShape(editor, endNode)
)
return [];

const colorPalette = DefaultColorThemePalette.lightMode;
const validTypes: RelationTypeOption[] = [];
const seenLabels = new Set<string>();

for (const relation of getAllRelations()) {
const { isDirect, isReverse } = checkConnectionType(
relation,
startNode.type,
endNode.type,
);
if (!isDirect && !isReverse) continue;

const label =
isReverse && relation.complement ? relation.complement : relation.label;
if (seenLabels.has(label)) continue;
seenLabels.add(label);

const hexColor =
colorPalette[getRelationColor(relation.label)]?.solid ?? "#333";
validTypes.push({ id: relation.id, label, color: hexColor });
}

return validTypes;
};

export const createRelationBetweenNodes = async ({
editor,
relationId,
sourceId,
targetId,
}: {
editor: Editor;
relationId: string;
sourceId: TLShapeId;
targetId: TLShapeId;
}): Promise<TLShapeId | null> => {
const selectedRelation = getAllRelations().find((r) => r.id === relationId);
if (!selectedRelation) return null;

const sourceNode = editor.getShape(sourceId);
const targetNode = editor.getShape(targetId);
const { isReverse } = checkConnectionType(
selectedRelation,
sourceNode?.type ?? "",
targetNode?.type ?? "",
);
const label =
isReverse && selectedRelation.complement
? selectedRelation.complement
: selectedRelation.label;

const sourceBounds = editor.getShapePageBounds(sourceId);
if (!sourceBounds) return null;

const arrowId = createShapeId();
editor.createShape<DiscourseRelationShape>({
id: arrowId,
type: relationId,
x: sourceBounds.midX,
y: sourceBounds.midY,
props: {
color: getRelationColor(selectedRelation.label),
text: label,
dash: "draw",
size: "m",
fill: "none",
bend: 0,
start: { x: 0, y: 0 },
end: { x: 0, y: 0 },
arrowheadStart: "none",
arrowheadEnd: "arrow",
labelPosition: 0.5,
font: "draw",
scale: 1,
},
});

const newArrow = editor.getShape<DiscourseRelationShape>(arrowId);
if (!newArrow) return null;

createOrUpdateArrowBinding(editor, newArrow, sourceId, {
terminal: "start",
normalizedAnchor: { x: 0.5, y: 0.5 },
isPrecise: false,
isExact: false,
});
createOrUpdateArrowBinding(editor, newArrow, targetId, {
terminal: "end",
normalizedAnchor: { x: 0.5, y: 0.5 },
isPrecise: false,
isExact: false,
});

editor.select(arrowId);

const util = editor.getShapeUtil(newArrow);
if (
util instanceof BaseDiscourseRelationUtil &&
"handleCreateRelationsInRoam" in util
) {
type UtilWithRoamPersistence = BaseDiscourseRelationUtil & {
handleCreateRelationsInRoam: (args: {
arrow: DiscourseRelationShape;
targetId: TLShapeId;
}) => Promise<void>;
};
await (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({
arrow: editor.getShape<DiscourseRelationShape>(arrowId) ?? newArrow,
targetId,
});
}

// handleCreateRelationsInRoam deletes the new arrow if it rejects the
// conversion, so a surviving shape means the relation was persisted.
return editor.getShape(arrowId) ? arrowId : null;
};
Loading