Plugin to group and store nodes by their type

Hi, I have a couple of node types in my schema where part of each node is generated automatically from its position in the doc. Examples are figure and math nodes: on each update, I need to update the number of each figure. I'm doing this in a custom NodeView, where a function simply gets all nodes of that type and finds the index of the node in question in that array. That led me to the conclusion that it would be nice to group all nodes by type before rerendering happens. I also need to force some nodes to rerender; for that I used decorations. The plugin looks like this:

  addProseMirrorPlugins() {
    const { editor, options } = this;

    function compareStates(
      oldState: EditorState | undefined,
      newState: EditorState
    ): Omit<INodesState, 'decorations'> {
      let nodes: Record<string, NodeWithPos[]> = {};
      let refresh: Record<string, NodeWithPos[]> = {};

      if (!oldState || oldState.doc !== newState.doc || true) {
        const prevNodesById: Record<string, NodeWithPos> = {};
        oldState &&
          oldState.doc.descendants((node, pos) => {
            if (node.attrs.id) {
              prevNodesById[node.attrs.id] = { node, pos };
            }
          });

        const nextNodesById: Record<string, NodeWithPos> = {};
        newState.doc.descendants((node, pos) => {
          if (node.attrs.id) {
            nextNodesById[node.attrs.id] = { node, pos };
          }

          // Push in place instead of copying the array for every node (avoids O(n²))
          if (!nodes[node.type.name]) nodes[node.type.name] = [];
          nodes[node.type.name].push({ node, pos });
        });

        if (oldState) {
          const deletedIds = new Set<string>();
          const changedIds = new Set<string>();
          const addedIds = new Set<string>();

          for (const [id, node] of Object.entries(prevNodesById)) {
            if (nextNodesById[id] === undefined) {
              deletedIds.add(id);
            } else if (node.node !== nextNodesById[id].node) {
              // Compare the ProseMirror nodes, not the { node, pos } wrappers,
              // which are new objects on every call and are never equal
              changedIds.add(id);
            }
          }

          for (const [id, node] of Object.entries(nextNodesById)) {
            if (prevNodesById[id] === undefined) {
              addedIds.add(id);
            } else if (node.node !== prevNodesById[id].node) {
              changedIds.add(id);
            }
          }

          console.log({
            deletedIds,
            changedIds,
            addedIds,
          });

          const touchedIds = [...deletedIds, ...changedIds, ...addedIds];

          // Which nodes should be refreshed?
          for (let name of Object.keys(nodes)) {
            if (!options.types.includes(name)) continue;
            let tempNodes = nodes[name];
            refresh[name] = tempNodes.filter((x) =>
              touchedIds.includes(x.node.attrs.id)
            );
          }
        }
      }

      return {
        nodes,
        refresh,
      };
    }

    const plugin = new Plugin<INodesState>({
      key: NODES_PLUGIN_KEY,
      state: {
        init(config, instance) {
          const refresh: Record<string, NodeWithPos[]> = {};
          const decorations: DecorationSet | null = null;
          return {
            nodes: compareStates(undefined, instance).nodes,
            refresh,
            decorations,
          };
        },
        apply(tr, value, oldState, newState) {
          if (tr.getMeta('refresh')) {
            const node: NodeWithPos = tr.getMeta('refresh');
            const decorations = DecorationSet.create(tr.doc, [
              Decoration.node(
                node.pos,
                node.pos + node.node.nodeSize,
                {},
                { refresh: Math.random() }
              ),
            ]);
            return { ...value, decorations: decorations };
          }

          return {
            ...compareStates(oldState, newState),
            // Map the stored decorations so their positions follow the edit
            decorations: value.decorations?.map(tr.mapping, tr.doc) ?? null,
          };
        },
      },

      props: {
        decorations(state) {
          const decorations = NODES_PLUGIN_KEY.getState(state)?.decorations;
          return decorations;
        },
      },

      appendTransaction(transactions, oldState, newState) {
        return null;
      },
    });

    return [plugin];
  },

So in both init() and apply() I iterate over all nodes in newState and store them in the plugin state, where other nodes can access them. I also compare oldState and newState to figure out which nodes or node types should be rerendered; to make the comparison easier, I attach a unique id to each such node. For each resulting node, I change its decorations to force it to rerender. This is how I figure out which number to attach to a node:
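A minimal, framework-free sketch of the id-based comparison described above. It assumes a hypothetical `TrackedNode` shape standing in for `NodeWithPos`, and relies on ProseMirror reusing unchanged node objects between documents, so `!==` on the nodes themselves (not on freshly created `{ node, pos }` wrappers) detects edits. `Map` and `Set` keep each lookup constant-time instead of scanning an array with `includes`.

```typescript
// Hypothetical stand-in for NodeWithPos: only the fields the diff needs.
interface TrackedNode {
  node: { attrs: { id?: string } };
  pos: number;
}

interface IdDiff {
  added: Set<string>;
  changed: Set<string>;
  deleted: Set<string>;
}

// Index entries by their id attribute; entries without an id are skipped.
function indexById(entries: TrackedNode[]): Map<string, TrackedNode> {
  const byId = new Map<string, TrackedNode>();
  for (const entry of entries) {
    const id = entry.node.attrs.id;
    if (id) byId.set(id, entry);
  }
  return byId;
}

// A node counts as "changed" when the node object itself differs;
// a change of position alone is not reported.
function diffById(prev: TrackedNode[], next: TrackedNode[]): IdDiff {
  const prevById = indexById(prev);
  const nextById = indexById(next);
  const diff: IdDiff = {
    added: new Set<string>(),
    changed: new Set<string>(),
    deleted: new Set<string>(),
  };

  prevById.forEach((before, id) => {
    const after = nextById.get(id);
    if (after === undefined) diff.deleted.add(id);
    else if (after.node !== before.node) diff.changed.add(id);
  });
  nextById.forEach((_after, id) => {
    if (!prevById.has(id)) diff.added.add(id);
  });
  return diff;
}

// Usage: plain objects stand in for ProseMirror nodes.
const figA = { attrs: { id: 'a' } };
const figB = { attrs: { id: 'b' } };
const figBEdited = { attrs: { id: 'b' } }; // same id, new node object
const figC = { attrs: { id: 'c' } };

const result = diffById(
  [{ node: figA, pos: 0 }, { node: figB, pos: 5 }],
  [{ node: figBEdited, pos: 0 }, { node: figC, pos: 7 }]
);
console.log(Array.from(result.deleted), Array.from(result.changed), Array.from(result.added));
// [ 'a' ] [ 'b' ] [ 'c' ]
```

Because a moved but otherwise untouched node keeps its object identity, it is not reported as changed here; that is why position-dependent types still need a blanket refresh.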

    const nodesWithPos =
      NODES_PLUGIN_KEY.getState(this._outerView.state)?.nodes[
        this.node.type.name
      ] || [];
    const nodes = nodesWithPos.map((x) => x.node);
    const number = nodes.indexOf(this.node) + 1;
    if (this._numberElt) {
      // classList.remove() is a no-op when the class is absent
      this._numberElt.classList.remove('hidden');
      this._numberElt.textContent = '(' + number + ')';
    }
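Each NodeView above maps the whole list and calls `indexOf`, so numbering n nodes costs O(n) per view and O(n²) per render. A sketch of an alternative, assuming the numbers are computed once per state (for example in the plugin's `apply`, next to `nodes`) and stored in a `WeakMap` keyed by node object; `numberNodes` and `numberLabel` are hypothetical helpers, and node identity is used exactly as `indexOf(this.node)` uses it.

```typescript
// Assign 1-based numbers in document order, once per state. Keys are the
// node objects themselves, the same identity lookup indexOf(this.node) does.
function numberNodes<T extends object>(nodesInOrder: T[]): WeakMap<T, number> {
  const numbers = new WeakMap<T, number>();
  nodesInOrder.forEach((node, index) => {
    // Keep the first occurrence if the same object appears twice, like indexOf
    if (!numbers.has(node)) numbers.set(node, index + 1);
  });
  return numbers;
}

// Label for a node; empty when the node is not in the numbered list.
function numberLabel<T extends object>(numbers: WeakMap<T, number>, node: T): string {
  const n = numbers.get(node);
  return n === undefined ? '' : '(' + n + ')';
}

// Usage: plain objects stand in for figure nodes.
const fig1 = { type: 'figure' };
const fig2 = { type: 'figure' };
const numbers = numberNodes([fig1, fig2]);
console.log(numberLabel(numbers, fig2)); // (2)
```

In the NodeView, the lookup then becomes a single `get` on the stored map. Returning an empty label for a node that is not in the list is a design choice; the `indexOf` version shows `(0)` in that case.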

Additionally, I found that I need a method that refreshes a given node without touching the rest of the doc. I did it in the following way:

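    refreshNodeAtPos:
        (node) =>
        ({ tr, dispatch }) => {
          // Attach the node to the current transaction instead of starting a
          // nested chain; the plugin's apply() reads it from the meta
          if (dispatch) tr.setMeta('refresh', node);
          return true;
        },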

So I store the node to refresh as transaction metadata, and my plugin then applies a decoration to that node, as shown in the plugin's apply() method. I register this plugin last to make sure that newState is the final state. I guess that won't hold when other plugins have an appendTransaction method, right? Unfortunately, I couldn't find a way to update plugin state from the appendTransaction method.

QUESTION: @marijn, do you see any logic issues in the code I presented, for example regarding performance? Any suggestions for what could be done better?

Appended transactions will also go through state field apply methods. You won’t miss any state changes when observing transactions that way.
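To illustrate that answer, here is a heavily simplified model of the loop in `EditorState.applyTransaction`. The types and names are hypothetical stand-ins, not the real prosemirror-state API, and transaction filtering and the per-plugin old states are left out. It shows that every transaction, including ones returned by `appendTransaction`, runs through all state fields' `apply`, and that each appended transaction is tagged with the root transaction under the `appendedTransaction` meta key.

```typescript
// Toy model only: simplified stand-ins, not the prosemirror-state API.
interface ToyTransaction {
  label: string;
  meta: Map<string, unknown>;
}

interface ToyPlugin {
  // State-field apply: runs for the root transaction and every appended one.
  apply?: (tr: ToyTransaction, value: string[]) => string[];
  // May return a follow-up transaction after seeing new transactions.
  appendTransaction?: (trs: ToyTransaction[]) => ToyTransaction | null;
}

function applyTransaction(
  plugins: ToyPlugin[],
  state: string[],
  rootTr: ToyTransaction
): { state: string[]; transactions: ToyTransaction[] } {
  // Run every plugin's apply for one transaction
  const applyInner = (s: string[], tr: ToyTransaction): string[] =>
    plugins.reduce((acc, p) => (p.apply ? p.apply(tr, acc) : acc), s);

  const trs: ToyTransaction[] = [rootTr];
  let newState = applyInner(state, rootTr);
  // How many transactions each plugin's appendTransaction has already seen
  const seen = plugins.map(() => 0);

  for (;;) {
    let haveNew = false;
    for (let i = 0; i < plugins.length; i++) {
      const append = plugins[i].appendTransaction;
      if (!append || seen[i] >= trs.length) continue;
      const tr = append(trs.slice(seen[i]));
      if (tr) {
        // Appended transactions carry a reference to the root transaction
        tr.meta.set('appendedTransaction', rootTr);
        trs.push(tr);
        newState = applyInner(newState, tr);
        haveNew = true;
      }
      seen[i] = trs.length;
    }
    if (!haveNew) return { state: newState, transactions: trs };
  }
}

// A state field that records the label of every transaction it sees
const recorder: ToyPlugin = { apply: (tr, value) => value.concat(tr.label) };
// A plugin that appends one follow-up transaction after the root one
const appender: ToyPlugin = {
  appendTransaction: (trs) =>
    trs.some((t) => t.label === 'root') ? { label: 'appended', meta: new Map() } : null,
};

const out = applyTransaction([recorder, appender], [], { label: 'root', meta: new Map() });
console.log(out.state); // [ 'root', 'appended' ]
```

The loop keeps going until no plugin appends anything new, so plugin order does not decide whether a state field sees a transaction, only when.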

Indeed, the cycle is plugin apply → appendTransaction, so I will capture all transactions this way. Below is the updated code. Do you see any logic (or performance) issues, or do you have suggestions on how to do it in a better way? I added

    const node: NodeWithPos | undefined =
      tr.getMeta('refresh') ||
      tr.getMeta('appendedTransaction')?.getMeta('refresh');

because at some later stage the transaction itself no longer carries the refresh meta, so I have to fall back to the transaction stored under the appendedTransaction meta.
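This works because ProseMirror stores the root transaction under the `appendedTransaction` meta key of every appended transaction, so the refresh request can be recovered from the root. Presumably another plugin appends a transaction in response to the refresh one, and this plugin's `apply` then runs again for the appended transaction. A small sketch of the lookup as a reusable helper; `MetaSource`, `getMetaOrRoot` and `FakeTr` are hypothetical (a real `Transaction` satisfies `MetaSource` structurally). One thing to keep in mind with this fallback: an appended transaction that also changes the document takes the refresh branch, so `nodes` is not recomputed for it.

```typescript
// Minimal structural stand-in for a ProseMirror Transaction (assumption:
// only getMeta is needed here).
interface MetaSource {
  getMeta(key: string): unknown;
}

// Look up a meta value on the transaction itself, falling back to the root
// transaction that ProseMirror stores under 'appendedTransaction'.
function getMetaOrRoot<T>(tr: MetaSource, key: string): T | undefined {
  const own = tr.getMeta(key);
  if (own !== undefined) return own as T;
  const root = tr.getMeta('appendedTransaction') as MetaSource | undefined;
  return root ? (root.getMeta(key) as T | undefined) : undefined;
}

// Hypothetical fake transaction used only for this example
class FakeTr implements MetaSource {
  private meta = new Map<string, unknown>();
  setMeta(key: string, value: unknown): this {
    this.meta.set(key, value);
    return this;
  }
  getMeta(key: string): unknown {
    return this.meta.get(key);
  }
}

const root = new FakeTr().setMeta('refresh', { pos: 3 });
const appended = new FakeTr().setMeta('appendedTransaction', root);
console.log(getMetaOrRoot<{ pos: number }>(appended, 'refresh')); // { pos: 3 }
```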

Here is the full plugin code:


  addProseMirrorPlugins() {
    const { editor, options } = this;

    function groupNodes(state: EditorState | undefined) {
      const nodes: Record<string, NodeWithPos[]> = {};
      const nodesById: Record<string, NodeWithPos> = {};

      if (!state) return { nodes, nodesById };

      state.doc.descendants((node, pos) => {
        if (node.attrs.id) {
          nodesById[node.attrs.id] = { node, pos };
        }
        // Push in place instead of copying the array for every node (avoids O(n²))
        if (!nodes[node.type.name]) nodes[node.type.name] = [];
        nodes[node.type.name].push({ node, pos });
      });
      return { nodes, nodesById };
    }

    function compareStates(
      oldState: EditorState | undefined,
      newState: EditorState
    ): Omit<INodesState, 'decorations'> {
      const refresh: Record<string, NodeWithPos[]> = {};
      const { nodesById: prevNodesById } = groupNodes(oldState);
      const { nodes, nodesById: nextNodesById } = groupNodes(newState);

      // Figure out which nodes have changed and should be refreshed
      if (oldState) {
        const deletedIds = new Set<string>();
        const changedIds = new Set<string>();
        const addedIds = new Set<string>();

        for (const [id, node] of Object.entries(prevNodesById)) {
          if (nextNodesById[id] === undefined) {
            deletedIds.add(id);
          } else if (node.node !== nextNodesById[id].node) {
            changedIds.add(id);
          }
        }

        for (const [id, node] of Object.entries(nextNodesById)) {
          if (prevNodesById[id] === undefined) {
            addedIds.add(id);
          } else if (node.node !== prevNodesById[id].node) {
            changedIds.add(id);
          }
        }

        console.log({
          deletedIds,
          changedIds,
          addedIds,
        });

        const touchedIds = new Set([...deletedIds, ...changedIds, ...addedIds]);

        for (const name of Object.keys(nodes)) {
          if (!options.types.includes(name)) continue;
          const tempNodes = nodes[name];
          // Most often we need to refresh all nodes of a given type
          if (
            [
              'math_display',
              'caption',
              'remark',
              'annotation',
            ].includes(name)
          ) {
            refresh[name] = tempNodes;
            continue;
          }
          // But sometimes we might want to refresh only changed nodes
          refresh[name] = tempNodes.filter((x) =>
            touchedIds.has(x.node.attrs.id)
          );
        }
      }

      return {
        nodes,
        refresh,
      };
    }

    const plugin = new Plugin<INodesState>({
      key: NODES_PLUGIN_KEY,
      state: {
        init(config, instance) {
          let refresh: Record<string, NodeWithPos[]> = {};
          return {
            nodes: groupNodes(instance).nodes,
            refresh,
          };
        },
        apply(tr, value, oldState, newState) {
          const node: NodeWithPos | undefined =
            tr.getMeta('refresh') ||
            tr.getMeta('appendedTransaction')?.getMeta('refresh');

          if (node) {
            const refresh: Record<string, NodeWithPos[]> = {
              [node.node.type.name]: [node],
            };
            return { ...value, refresh };
          }

          return compareStates(oldState, newState);
        },
      },

      props: {
        decorations(state) {
          const refresh = NODES_PLUGIN_KEY.getState(state)?.refresh;
          if (!refresh) return null;

          const decorations: Decoration[] = [];
          for (const name of Object.keys(refresh)) {
            const decos = refresh[name].map((node) =>
              Decoration.node(
                node.pos,
                node.pos + node.node.nodeSize,
                {},
                { refresh: Math.random() }
              )
            );
            decorations.push(...decos);
          }

          return DecorationSet.create(state.doc, decorations);
        },
      },
    });

    return [plugin];
  },
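One thing worth checking in this version: `decorations` is a prop, so ProseMirror calls it on every view update, including selection-only ones, and `Math.random()` gives each call a new spec. `apply` also recomputes `refresh` for every transaction, so nodes of the always-refreshed types (`math_display`, `caption`, `remark`, `annotation`) are in it even when only the selection moved. Together this likely re-renders those nodes on every cursor move. A framework-free sketch of one way around it, under simplified hypothetical types (`RefreshState`, `SimpleTr`, `DecorationData` are not ProseMirror types): return the previous state object when the document is unchanged and there is no refresh request, and build the decoration data once per `refresh` object, with a counter instead of `Math.random()`.

```typescript
// Simplified, hypothetical shapes standing in for NodeWithPos / INodesState.
interface RefreshTarget {
  pos: number;
  nodeSize: number;
}
interface RefreshState {
  refresh: Record<string, RefreshTarget[]>;
}
interface SimpleTr {
  docChanged: boolean;
  refreshMeta?: { typeName: string; target: RefreshTarget };
}
interface DecorationData {
  from: number;
  to: number;
  spec: { refresh: number };
}

// State apply: only produce a new state object when something changed.
function applyState(
  tr: SimpleTr,
  value: RefreshState,
  recompute: () => RefreshState
): RefreshState {
  if (tr.refreshMeta) {
    return { refresh: { [tr.refreshMeta.typeName]: [tr.refreshMeta.target] } };
  }
  // Selection-only transactions keep the same object, so caches keyed on it hit
  if (!tr.docChanged) return value;
  return recompute();
}

// Decoration data is built once per `refresh` object. A counter replaces
// Math.random(): each new batch gets a distinct spec, repeated calls do not.
let batch = 0;
const cache = new WeakMap<Record<string, RefreshTarget[]>, DecorationData[]>();

function decorationData(state: RefreshState): DecorationData[] {
  const cached = cache.get(state.refresh);
  if (cached) return cached;
  batch += 1;
  const data: DecorationData[] = [];
  for (const name of Object.keys(state.refresh)) {
    for (const { pos, nodeSize } of state.refresh[name]) {
      data.push({ from: pos, to: pos + nodeSize, spec: { refresh: batch } });
    }
  }
  cache.set(state.refresh, data);
  return data;
}

// Usage
const initial: RefreshState = { refresh: { math_display: [{ pos: 4, nodeSize: 3 }] } };
const first = decorationData(initial);
const afterSelection = applyState({ docChanged: false }, initial, () => ({ refresh: {} }));
const second = decorationData(afterSelection);
console.log(first === second); // true: same spec, so no forced re-render
const afterRefresh = applyState(
  { docChanged: false, refreshMeta: { typeName: 'figure', target: { pos: 10, nodeSize: 2 } } },
  afterSelection,
  () => ({ refresh: {} })
);
console.log(decorationData(afterRefresh)[0].spec.refresh); // 2
```

In the real plugin, the cached value would be the `DecorationSet` itself, or, as in the first version, the set could be built in `apply` and kept in the plugin state so that `decorations(state)` only returns it. The module-level counter is just for brevity.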