diff --git a/src/slang-nodes/Expression.ts b/src/slang-nodes/Expression.ts index c701efd74..684498d11 100644 --- a/src/slang-nodes/Expression.ts +++ b/src/slang-nodes/Expression.ts @@ -106,7 +106,11 @@ export class Expression extends SlangNode { | ElementaryType['variant'] | TerminalNode; - constructor(ast: ast.Expression, collected: CollectedMetadata) { + constructor( + ast: ast.Expression, + collected: CollectedMetadata, + endOfChain?: boolean + ) { super(ast, collected); const variant = ast.variant; @@ -114,7 +118,7 @@ export class Expression extends SlangNode { this.variant = new TerminalNode(variant, collected); return; } - this.variant = createNonterminalVariant(variant, collected); + this.variant = createNonterminalVariant(variant, collected, endOfChain); this.updateMetadata(this.variant); } diff --git a/src/slang-nodes/FunctionCallExpression.ts b/src/slang-nodes/FunctionCallExpression.ts index 79945a073..94de5cd57 100644 --- a/src/slang-nodes/FunctionCallExpression.ts +++ b/src/slang-nodes/FunctionCallExpression.ts @@ -16,10 +16,16 @@ export class FunctionCallExpression extends SlangNode { arguments: ArgumentsDeclaration['variant']; - constructor(ast: ast.FunctionCallExpression, collected: CollectedMetadata) { + constructor( + ast: ast.FunctionCallExpression, + collected: CollectedMetadata, + endOfChain?: boolean + ) { super(ast, collected); - this.operand = extractVariant(new Expression(ast.operand, collected)); + this.operand = extractVariant( + new Expression(ast.operand, collected, endOfChain) + ); this.arguments = extractVariant( new ArgumentsDeclaration(ast.arguments, collected) ); diff --git a/src/slang-nodes/IndexAccessExpression.ts b/src/slang-nodes/IndexAccessExpression.ts index 6f387854a..f7f18c8de 100644 --- a/src/slang-nodes/IndexAccessExpression.ts +++ b/src/slang-nodes/IndexAccessExpression.ts @@ -19,10 +19,16 @@ export class IndexAccessExpression extends SlangNode { end?: IndexAccessEnd; - constructor(ast: ast.IndexAccessExpression, collected: CollectedMetadata) { + constructor( + ast: ast.IndexAccessExpression, + collected: CollectedMetadata, + endOfChain?: boolean + ) { super(ast, collected); - this.operand = extractVariant(new Expression(ast.operand, collected)); + this.operand = extractVariant( + new Expression(ast.operand, collected, endOfChain) + ); if (ast.start) { this.start = extractVariant(new Expression(ast.start, collected)); } diff --git a/src/slang-nodes/MemberAccessExpression.ts b/src/slang-nodes/MemberAccessExpression.ts index cc97f126f..f19208c0e 100644 --- a/src/slang-nodes/MemberAccessExpression.ts +++ b/src/slang-nodes/MemberAccessExpression.ts @@ -2,51 +2,19 @@ import { NonterminalKind } from '@nomicfoundation/slang/cst'; import { doc } from 'prettier'; import { isLabel } from '../slang-utils/is-label.js'; import { extractVariant } from '../slang-utils/extract-variant.js'; -import { isChainableExpression } from '../slang-utils/is-chainable-expression.js'; import { memberAccessChainLabel } from '../slang-printers/print-member-access-chain-item.js'; import { SlangNode } from './SlangNode.js'; import { Expression } from './Expression.js'; import { TerminalNode } from './TerminalNode.js'; import type * as ast from '@nomicfoundation/slang/ast'; -import type { AstPath, Doc } from 'prettier'; +import type { Doc } from 'prettier'; import type { CollectedMetadata, PrintFunction } from '../types.d.ts'; -import type { ChainableExpression, PrintableNode } from './types.d.ts'; const { group, indent, label, softline } = doc.builders; const separatorLabel = Symbol('separator'); -function isEndOfChain( - node: ChainableExpression, - path: AstPath -): boolean { - for (let i = 1, current = node, parent; ; i++, current = parent) { - parent = path.getNode(i)!; - if (!isChainableExpression(parent)) break; - - switch (parent.kind) { - case NonterminalKind.MemberAccessExpression: - // If `parent` is a MemberAccessExpression we are not at the end - // of the chain. - return false; - case NonterminalKind.IndexAccessExpression: - // If `parent` is an IndexAccessExpression and `current` is not - // the operand then it must be the start or the end in which case it is - // the end of the chain. - if (current !== parent.operand) return true; - break; - case NonterminalKind.FunctionCallExpression: - // If `parent` is a FunctionCallExpression and `current` is not - // the operand then it must be and argument in which case it is the end - // of the chain. - if (current !== parent.operand) return true; - break; - } - } - return true; -} - /** * processChain expects the doc[] of the full chain of MemberAccess. * @@ -111,20 +79,30 @@ function processChain(chain: Doc[]): Doc { export class MemberAccessExpression extends SlangNode { readonly kind = NonterminalKind.MemberAccessExpression; + readonly #endOfChain: boolean; + operand: Expression['variant']; member: TerminalNode; - constructor(ast: ast.MemberAccessExpression, collected: CollectedMetadata) { + constructor( + ast: ast.MemberAccessExpression, + collected: CollectedMetadata, + endOfChain = true + ) { super(ast, collected); - this.operand = extractVariant(new Expression(ast.operand, collected)); + this.operand = extractVariant( + new Expression(ast.operand, collected, false) + ); this.member = new TerminalNode(ast.member, collected); this.updateMetadata(this.operand); + + this.#endOfChain = endOfChain; } - print(print: PrintFunction, path: AstPath): Doc { + print(print: PrintFunction): Doc { let operandDoc = print('operand'); if (Array.isArray(operandDoc)) { operandDoc = operandDoc.flat(); @@ -136,6 +114,6 @@ export class MemberAccessExpression extends SlangNode { print('member') ].flat(); - return isEndOfChain(this, path) ? processChain(document) : document; + return this.#endOfChain ? processChain(document) : document; } } diff --git a/src/slang-printers/print-assignment-right-side.ts b/src/slang-printers/print-assignment-right-side.ts index d75958f67..6deb22d32 100644 --- a/src/slang-printers/print-assignment-right-side.ts +++ b/src/slang-printers/print-assignment-right-side.ts @@ -1,9 +1,19 @@ import { NonterminalKind } from '@nomicfoundation/slang/cst'; -import { isChainableExpression } from '../slang-utils/is-chainable-expression.js'; +import { createKindCheckFunction } from '../slang-utils/create-kind-check-function.js'; import { printIndentedGroupOrSpacedDocument } from './print-indented-group-or-spaced-document.js'; import type { Doc, doc } from 'prettier'; import type { Expression } from '../slang-nodes/Expression.ts'; +import type { + ChainableExpression, + PrintableNode +} from '../slang-nodes/types.d.ts'; + +const isChainableExpression = createKindCheckFunction([ + NonterminalKind.FunctionCallExpression, + NonterminalKind.IndexAccessExpression, + NonterminalKind.MemberAccessExpression +]) as (node: PrintableNode) => node is ChainableExpression; export function printAssignmentRightSide( document: Doc, diff --git a/src/slang-utils/create-nonterminal-variant-creator.ts b/src/slang-utils/create-nonterminal-variant-creator.ts index 9dd7b0c82..cbdc01641 100644 --- a/src/slang-utils/create-nonterminal-variant-creator.ts +++ b/src/slang-utils/create-nonterminal-variant-creator.ts @@ -8,13 +8,21 @@ import type { SlangAstNodeClass } from '../types.d.ts'; -type NodeConstructor = new (ast: any, collected: CollectedMetadata) => T; +type NodeConstructor = new ( + ast: any, + collected: CollectedMetadata, + endOfChain?: boolean +) => T; type SlangPolymorphicNode = Extract; type NonterminalVariantFactory< U extends SlangPolymorphicNode, T extends StrictPolymorphicNode -> = (variant: U['variant'], collected: CollectedMetadata) => T['variant']; +> = ( + variant: U['variant'], + collected: CollectedMetadata, + endOfChain?: boolean +) => T['variant']; export function createNonterminalVariantSimpleCreator< U extends SlangPolymorphicNode, @@ -22,10 +30,10 @@ export function createNonterminalVariantSimpleCreator< >( constructors: [SlangAstNodeClass, NodeConstructor][] ): NonterminalVariantFactory { - return (variant, collected) => { + return (variant, collected, endOfChain) => { for (const [slangAstClass, constructor] of constructors) { if (variant instanceof slangAstClass) - return new constructor(variant, collected); + return new constructor(variant, collected, endOfChain); } throw new Error(`Unexpected variant: ${JSON.stringify(variant)}`); @@ -46,12 +54,12 @@ export function createNonterminalVariantCreator< constructors ); - return (variant, collected) => { + return (variant, collected, endOfChain) => { for (const [slangAstClass, constructor] of extractVariantConstructors) { if (variant instanceof slangAstClass) - return extractVariant(new constructor(variant, collected)); + return extractVariant(new constructor(variant, collected, endOfChain)); } - return simpleCreator(variant, collected); + return simpleCreator(variant, collected, endOfChain); }; } diff --git a/src/slang-utils/is-chainable-expression.ts b/src/slang-utils/is-chainable-expression.ts deleted file mode 100644 index 7a7b13137..000000000 --- a/src/slang-utils/is-chainable-expression.ts +++ /dev/null @@ -1,13 +0,0 @@ -import { NonterminalKind } from '@nomicfoundation/slang/cst'; -import { createKindCheckFunction } from './create-kind-check-function.js'; - -import type { - ChainableExpression, - PrintableNode -} from '../slang-nodes/types.d.ts'; - -export const isChainableExpression = createKindCheckFunction([ - NonterminalKind.FunctionCallExpression, - NonterminalKind.IndexAccessExpression, - NonterminalKind.MemberAccessExpression -]) as (node: PrintableNode) => node is ChainableExpression;