Python Visualization

The following files support the Python visualization.

smoothing_types.py

r"""This module, smoothing_types.py, defines types used for smoothing
hexahedral meshes.
"""

from enum import Enum
from typing import NamedTuple


class Vertex(NamedTuple):
    """A general 3D vertex with x, y, and z coordinates.

    Because this is a NamedTuple, instances are immutable and compare
    equal to a plain 3-tuple holding the same values in (x, y, z) order.
    """

    x: float  # first coordinate
    y: float  # second coordinate
    z: float  # third coordinate


class Hierarchy(Enum):
    """All nodes must be categorized as belonging to one, and only one,
    of the following hierarchical categories.

    The integer values are ordered; smoothing_neighbors in smoothing.py
    compares them with `>=` to decide which neighbors may influence a
    BOUNDARY node.
    """

    INTERIOR = 0  # keeps all of its neighbors when smoothing
    BOUNDARY = 1  # keeps only neighbors whose level is BOUNDARY or higher
    PRESCRIBED = 2  # has no smoothing neighbors, so its position is held


# A sequence of vertices; smoothing.py looks vertices up by (node number - 1).
Vertices = tuple[Vertex, ...]
# The eight node numbers that define a single hexahedral element.
Hex = tuple[int, int, int, int, int, int, int, int]  # only hex elements
# The connectivity of every hex element in a mesh.
Hexes = tuple[Hex, ...]
# The node numbers connected to a single node.
Neighbor = tuple[int, ...]
# One Neighbor entry per node.
Neighbors = tuple[Neighbor, ...]
# One Hierarchy classification per node.
NodeHierarchy = tuple[Hierarchy, ...]
# Pairs of (node number, prescribed position), or None when there are none.
PrescribedNodes = tuple[tuple[int, Vertex], ...] | None


class SmoothingAlgorithm(Enum):
    """The type of smoothing algorithm.

    The member values are human-readable names; smoothing.smooth prints
    `algorithm.value` when it starts.
    """

    LAPLACE = "Laplace"
    TAUBIN = "Taubin"


class SmoothingExample(NamedTuple):
    """The prototype smoothing example.

    Bundles the mesh data and the smoothing parameters that describe a
    single example problem.
    """

    vertices: Vertices  # initial vertex positions
    elements: Hexes  # hex element connectivity
    nelx: int  # presumably the number of elements along x -- TODO confirm
    nely: int  # presumably the number of elements along y -- TODO confirm
    nelz: int  # presumably the number of elements along z -- TODO confirm
    # neighbors: Neighbors
    # (disabled field: the tests derive neighbors from `elements` instead,
    # via smoothing.node_node_connectivity)
    node_hierarchy: NodeHierarchy  # one Hierarchy classification per node
    prescribed_nodes: PrescribedNodes  # (node number, position) pairs or None
    scale_lambda: float  # scale factor applied to each smoothing update
    scale_mu: float  # presumably the second Taubin scale factor -- TODO confirm
    num_iters: int  # number of smoothing iterations
    algorithm: SmoothingAlgorithm  # which smoothing algorithm is requested
    file_stem: str  # presumably the stem for output file names -- TODO confirm

smoothing_test.py

r"""This module, smoothing_test.py, tests the smoothing modules.

Example:
--------
source ~/autotwin/automesh/.venv/bin/activate
cd ~/autotwin/automesh/book/examples/smoothing
python -m pytest smoothing_test.py

Reference:
----------
DoubleX unit test
https://autotwin.github.io/automesh/examples/unit_tests/index.html#double-x
"""

from typing import Final

# import sandbox.smoothing as sm
# import sandbox.smoothing_types as ty
import smoothing as sm
import smoothing_examples as examples
import smoothing_types as ty

# Type aliases for functional style methods, so the tests below can refer to
# the smoothing types without the `ty.` prefix.
# https://docs.python.org/3/library/typing.html#type-aliases
Hexes = ty.Hexes
Hierarchy = ty.Hierarchy
Neighbors = ty.Neighbors
NodeHierarchy = ty.NodeHierarchy
Vertex = ty.Vertex
Vertices = ty.Vertices
SmoothingAlgorithm = ty.SmoothingAlgorithm


def test_average_position():
    """Check that average_position returns the component-wise mean."""
    samples = (
        Vertex(x=1.0, y=2.0, z=3.0),
        Vertex(x=4.0, y=5.0, z=6.0),
        Vertex(x=7.0, y=8.0, z=9.0),
    )

    mean = sm.average_position(samples)
    assert mean.x == 4.0
    assert mean.y == 5.0
    assert mean.z == 6.0

    # the example shown in the average_position docstring
    first, second = Vertex(1, 2, 3), Vertex(4, 5, 6)
    expected = Vertex(2.5, 3.5, 4.5)
    assert sm.average_position((first, second)) == expected


def test_add():
    """Check the component-wise sum of two Vertex objects."""
    total = sm.add(
        v1=Vertex(x=1.0, y=2.0, z=3.0),
        v2=Vertex(x=4.0, y=7.0, z=1.0),
    )
    assert total.x == 5.0
    assert total.y == 9.0
    assert total.z == 4.0

    # the example shown in the add docstring
    left, right = Vertex(1, 2, 3), Vertex(4, 5, 6)
    expected = Vertex(5, 7, 9)
    assert sm.add(left, right) == expected


def test_subtract():
    """Check the component-wise difference (v1 - v2) of two Vertex objects."""
    difference = sm.subtract(
        v1=Vertex(x=1.0, y=2.0, z=3.0),
        v2=Vertex(x=4.0, y=7.0, z=1.0),
    )
    assert difference.x == -3.0
    assert difference.y == -5.0
    assert difference.z == 2.0

    # the example shown in the subtract docstring
    minuend, subtrahend = Vertex(8, 5, 2), Vertex(1, 2, 3)
    expected = Vertex(7, 3, -1)
    assert sm.subtract(minuend, subtrahend) == expected


def test_scale():
    """Check that scale multiplies every coordinate by the scale factor."""
    scaled = sm.scale(vertex=Vertex(x=1.0, y=2.0, z=3.0), scale_factor=10.0)
    assert scaled.x == 10.0
    assert scaled.y == 20.0
    assert scaled.z == 30.0

    # the example shown in the scale docstring
    original = Vertex(1, 2, 3)
    factor = 2
    expected = Vertex(2, 4, 6)
    assert sm.scale(original, factor) == expected


def test_xyz():
    """Check that xyz returns the coordinates as an (x, y, z) tuple."""
    expected = (1.1, 2.2, 3.3)
    assert sm.xyz(Vertex(x=1.1, y=2.2, z=3.3)) == expected

    # the example shown in the xyz docstring
    integer_vertex = Vertex(1, 2, 3)
    assert sm.xyz(integer_vertex) == (1, 2, 3)


def test_smoothing_neighbors():
    """Check `smoothing_neighbors` on the Double X test problem, using a
    completely made up node hierarchy, and on the docstring example.
    """
    # short local names keep the hierarchy tables readable
    interior = Hierarchy.INTERIOR
    boundary = Hierarchy.BOUNDARY
    prescribed = Hierarchy.PRESCRIBED

    # the connectivity is derived from the Double X elements
    connectivity = sm.node_node_connectivity(examples.double_x.elements)

    made_up_hierarchy = (
        interior,  # 1
        boundary,  # 2
        prescribed,  # 3
        prescribed,  # 4
        boundary,  # 5
        interior,  # 6
        interior,  # 7
        boundary,  # 8
        boundary,  # 9
        interior,  # 10
        interior,  # 11
        interior,  # 12
    )

    expected_double_x = (
        (2, 4, 7),
        (3, 5, 8),
        (),
        (),
        (2, 4),
        (3, 5, 12),
        (1, 8, 10),
        (2, 9),
        (3, 8),
        (4, 7, 11),
        (5, 8, 10, 12),
        (6, 9, 11),
    )

    found = sm.smoothing_neighbors(
        neighbors=connectivity, node_hierarchy=made_up_hierarchy
    )
    assert found == expected_double_x

    # the example shown in the smoothing_neighbors docstring
    docstring_neighbors = ((2, 3), (1, 4), (1, 5), (2, 6), (3,), (4,))
    docstring_hierarchy = (
        interior,
        boundary,
        prescribed,
        boundary,
        interior,
        interior,
    )
    expected_docstring = ((2, 3), (4,), (), (2,), (3,), (4,))
    assert (
        sm.smoothing_neighbors(docstring_neighbors, docstring_hierarchy)
        == expected_docstring
    )


def test_laplace_hierarchical_bracket():
    """Unit test for Laplace smoothing with hierarchical control
    on the Bracket example.

    First checks the smoothing neighbors derived from the bracket's
    elements and node hierarchy, then checks the vertex positions after
    ten smoothing iterations with lambda = 0.3 against stored gold values.
    """
    bracket = examples.bracket

    # neighbors are derived from the element connectivity
    neighbors = sm.node_node_connectivity(bracket.elements)
    node_hierarchy = bracket.node_hierarchy

    # If a node is PRESCRIBED, then it has no smoothing neighbors
    smoothing_neighbors = sm.smoothing_neighbors(
        neighbors=neighbors, node_hierarchy=node_hierarchy
    )
    gold_smoothing_neighbors = (
        (),  # 1
        (),  # 2
        (),  # 3
        (),  # 4
        (),  # 5
        (),  # 6
        (2, 6, 8, 12, 28),  # 7
        (3, 7, 9, 13, 29),  # 8
        (4, 8, 10, 14, 30),  # 9
        (),  # 10
        (),  # 11
        (7, 11, 13, 17, 33),  # 12
        (8, 12, 14, 18, 34),  # 13
        (9, 13, 15, 35),  # 14
        (),  # 15
        (),  # 16
        (12, 16, 18, 20, 38),  # 17
        (13, 17, 21, 39),  # 18
        (),  # 19
        (),  # 20
        (),  # 21
        (),  # 22
        (),  # 23
        (),  # 24
        (),  # 25
        (),  # 26
        (),  # 27
        (7, 23, 27, 29, 33),  # 28
        (8, 24, 28, 30, 34),  # 29
        (9, 25, 29, 31, 35),  # 30
        (),  # 31
        (),  # 32
        (12, 28, 32, 34, 38),  # 33
        (13, 29, 33, 35, 39),  # 34
        (14, 30, 34, 36),  # 35
        (),  # 36
        (),  # 37
        (17, 33, 37, 39, 41),  # 38
        (18, 34, 38, 42),  # 39
        (),  # 40
        (),  # 41
        (),  # 42
    )

    assert smoothing_neighbors == gold_smoothing_neighbors

    # specific test with lambda = 0.3 and num_iters = 10
    scale_lambda_test = 0.3
    num_iters_test = 10

    result = sm.smooth(
        vv=bracket.vertices,
        hexes=bracket.elements,
        node_hierarchy=node_hierarchy,
        prescribed_nodes=bracket.prescribed_nodes,
        scale_lambda=scale_lambda_test,
        num_iters=num_iters_test,
        algorithm=bracket.algorithm,
    )

    # gold positions are compared with exact float equality, so they are
    # only valid for the current order of arithmetic inside sm.smooth
    gold_vertices_10_iter = (
        Vertex(x=0, y=0, z=0),
        Vertex(x=1, y=0, z=0),
        Vertex(x=2, y=0, z=0),
        Vertex(x=3, y=0, z=0),
        Vertex(x=4, y=0, z=0),
        Vertex(x=0, y=1, z=0),
        Vertex(
            x=0.9974824535030984, y=0.9974824535030984, z=0.24593434133370803
        ),
        Vertex(
            x=1.9620726956646117, y=1.0109475009958278, z=0.2837944855813176
        ),
        Vertex(
            x=2.848322987789396, y=1.1190213008349328, z=0.24898414051620496
        ),
        Vertex(x=3.695518130045147, y=1.5307337294603591, z=0),
        Vertex(x=0, y=2, z=0),
        Vertex(
            x=1.0109475009958275, y=1.9620726956646117, z=0.2837944855813176
        ),
        Vertex(
            x=1.9144176939366933, y=1.9144176939366933, z=0.3332231502067546
        ),
        Vertex(
            x=2.5912759493290007, y=1.961874667390146, z=0.29909606343914835
        ),
        Vertex(x=2.8284271247461903, y=2.82842712474619, z=0),
        Vertex(x=0, y=3, z=0),
        Vertex(
            x=1.119021300834933, y=2.848322987789396, z=0.24898414051620493
        ),
        Vertex(
            x=1.9618746673901462, y=2.5912759493290007, z=0.29909606343914835
        ),
        Vertex(x=0, y=4, z=0),
        Vertex(x=1.5307337294603593, y=3.695518130045147, z=0),
        Vertex(x=2.8284271247461903, y=2.82842712474619, z=0),
        Vertex(x=0, y=0, z=1),
        Vertex(x=1, y=0, z=1),
        Vertex(x=2, y=0, z=1),
        Vertex(x=3, y=0, z=1),
        Vertex(x=4, y=0, z=1),
        Vertex(x=0, y=1, z=1),
        Vertex(
            x=0.9974824535030984, y=0.9974824535030984, z=0.7540656586662919
        ),
        Vertex(
            x=1.9620726956646117, y=1.0109475009958278, z=0.7162055144186824
        ),
        Vertex(x=2.848322987789396, y=1.119021300834933, z=0.7510158594837951),
        Vertex(x=3.695518130045147, y=1.5307337294603591, z=1),
        Vertex(x=0, y=2, z=1),
        Vertex(
            x=1.0109475009958275, y=1.9620726956646117, z=0.7162055144186824
        ),
        Vertex(
            x=1.9144176939366933, y=1.9144176939366933, z=0.6667768497932453
        ),
        Vertex(
            x=2.591275949329001, y=1.9618746673901462, z=0.7009039365608517
        ),
        Vertex(x=2.8284271247461903, y=2.82842712474619, z=1),
        Vertex(x=0, y=3, z=1),
        Vertex(x=1.1190213008349328, y=2.848322987789396, z=0.751015859483795),
        Vertex(
            x=1.9618746673901462, y=2.5912759493290007, z=0.7009039365608516
        ),
        Vertex(x=0, y=4, z=1),
        Vertex(x=1.5307337294603593, y=3.695518130045147, z=1),
        Vertex(x=2.8284271247461903, y=2.82842712474619, z=1),
    )

    assert result == gold_vertices_10_iter


def test_laplace_smoothing_double_x():
    """Check one and two iterations of Laplace smoothing on the Double X
    example, with every node categorized as BOUNDARY."""
    vv: Vertices = (
        # z = 0 layer
        Vertex(0.0, 0.0, 0.0),
        Vertex(1.0, 0.0, 0.0),
        Vertex(2.0, 0.0, 0.0),
        Vertex(0.0, 1.0, 0.0),
        Vertex(1.0, 1.0, 0.0),
        Vertex(2.0, 1.0, 0.0),
        # z = 1 layer
        Vertex(0.0, 0.0, 1.0),
        Vertex(1.0, 0.0, 1.0),
        Vertex(2.0, 0.0, 1.0),
        Vertex(0.0, 1.0, 1.0),
        Vertex(1.0, 1.0, 1.0),
        Vertex(2.0, 1.0, 1.0),
    )

    hexes: Hexes = (
        (1, 2, 5, 4, 7, 8, 11, 10),
        (2, 3, 6, 5, 8, 9, 12, 11),
    )

    # the node-to-node neighbors are not listed here; sm.smooth derives
    # them from `hexes`

    # all twelve nodes share the same category
    nh: NodeHierarchy = (Hierarchy.BOUNDARY,) * 12

    scale_lambda: Final[float] = 0.3  # lambda for Laplace smoothing
    algo = SmoothingAlgorithm.LAPLACE

    def smoothed(iterations):
        """Smooth the Double X mesh for the given number of iterations."""
        return sm.smooth(
            vv=vv,
            hexes=hexes,
            node_hierarchy=nh,
            prescribed_nodes=None,
            scale_lambda=scale_lambda,
            num_iters=iterations,
            algorithm=algo,
        )

    # a single iteration of smoothing
    after_one = smoothed(1)

    cc: Final[float] = scale_lambda / 3.0  # delta corner
    ee: Final[float] = scale_lambda / 4.0  # delta edge
    # the gold standard fiducial after one iteration
    expected_one = (
        Vertex(x=cc, y=cc, z=cc),  # node 1, corner
        Vertex(x=1.0, y=ee, z=ee),  # node 2, edge
        Vertex(x=2.0 - cc, y=cc, z=cc),  # node 3, corner
        #
        Vertex(x=cc, y=1.0 - cc, z=cc),  # node 4, corner
        Vertex(x=1.0, y=1.0 - ee, z=ee),  # node 5, edge
        Vertex(x=2.0 - cc, y=1.0 - cc, z=cc),  # node 6, corner
        #
        Vertex(x=cc, y=cc, z=1 - cc),  # node 7, corner
        Vertex(x=1.0, y=ee, z=1 - ee),  # node 8, edge
        Vertex(x=2.0 - cc, y=cc, z=1 - cc),  # node 9, corner
        #
        Vertex(x=cc, y=1.0 - cc, z=1 - cc),  # node 10, corner
        Vertex(x=1.0, y=1.0 - ee, z=1 - ee),  # node 11, edge
        Vertex(x=2.0 - cc, y=1.0 - cc, z=1 - cc),  # node 12, corner
    )
    assert after_one == expected_one

    # two iterations of smoothing, starting again from the original mesh
    after_two = smoothed(2)

    # the gold standard fiducial after two iterations
    expected_two = (
        (0.19, 0.1775, 0.1775),
        (1.0, 0.1425, 0.1425),
        (1.8099999999999998, 0.1775, 0.1775),
        (0.19, 0.8225, 0.1775),
        (1.0, 0.8575, 0.1425),
        (1.8099999999999998, 0.8225, 0.1775),
        (0.19, 0.1775, 0.8225),
        (1.0, 0.1425, 0.8575),
        (1.8099999999999998, 0.1775, 0.8225),
        (0.19, 0.8225, 0.8225),
        (1.0, 0.8575, 0.8575),
        (1.8099999999999998, 0.8225, 0.8225),
    )
    assert after_two == expected_two


def test_pair_ordered():
    """Check pair_ordered on a toy input, on the twelve edges of a hex
    element, and on the docstring example."""

    # small toy example
    toy_pairs = ((3, 1), (2, 1))
    toy_expected = ((1, 2), (1, 3))
    assert sm.pair_ordered(toy_pairs) == toy_expected

    # example from 12 edges of a hex element
    hex_edges = (
        (1, 2),
        (2, 5),
        (4, 1),
        (5, 4),
        (7, 8),
        (8, 11),
        (11, 10),
        (10, 7),
        (1, 7),
        (2, 8),
        (5, 11),
        (4, 10),
    )
    hex_expected = (
        (1, 2),
        (1, 4),
        (1, 7),
        (2, 5),
        (2, 8),
        (4, 5),
        (4, 10),
        (5, 11),
        (7, 8),
        (7, 10),
        (8, 11),
        (10, 11),
    )
    assert sm.pair_ordered(hex_edges) == hex_expected

    # the example shown in the pair_ordered docstring
    docstring_pairs = ((3, 1), (2, 4), (5, 0))
    docstring_expected = ((0, 5), (1, 3), (2, 4))
    assert sm.pair_ordered(docstring_pairs) == docstring_expected


def test_edge_pairs():
    """Unit test to assure edge pairs are computed correctly for two hex
    elements that share a face."""
    two_hexes = (
        (1, 2, 5, 4, 7, 8, 11, 10),
        (2, 3, 6, 5, 8, 9, 12, 11),
    )
    expected_edges = (
        (1, 2),
        (1, 4),
        (1, 7),
        (2, 3),
        (2, 5),
        (2, 8),
        (3, 6),
        (3, 9),
        (4, 5),
        (4, 10),
        (5, 6),
        (5, 11),
        (6, 12),
        (7, 8),
        (7, 10),
        (8, 9),
        (8, 11),
        (9, 12),
        (10, 11),
        (11, 12),
    )
    assert sm.edge_pairs(hexes=two_hexes) == expected_edges


def test_node_node_connectivity():
    """Check node_node_connectivity on three meshes: the Double X mesh,
    the same mesh with non-sequential node numbers, and the L-bracket.
    """

    # case 1: the Double X unit test
    double_x_hexes = (
        (1, 2, 5, 4, 7, 8, 11, 10),
        (2, 3, 6, 5, 8, 9, 12, 11),
    )
    double_x_expected = (
        (2, 4, 7),
        (1, 3, 5, 8),
        (2, 6, 9),
        (1, 5, 10),
        (2, 4, 6, 11),
        (3, 5, 12),
        (1, 8, 10),
        (2, 7, 9, 11),
        (3, 8, 12),
        (4, 7, 11),
        (5, 8, 10, 12),
        (6, 9, 11),
    )
    double_x_found = sm.node_node_connectivity(double_x_hexes)
    assert double_x_expected == double_x_found

    # case 2: renumber some nodes to assure the algorithm does not
    # assume sequential node numbers:
    # 2 -> 22
    # 5 -> 55
    # 8 -> 88
    # 11 -> 111
    renumbered_hexes = (
        (1, 22, 55, 4, 7, 88, 111, 10),
        (22, 3, 6, 55, 88, 9, 12, 111),
    )
    renumbered_expected = (
        (4, 7, 22),  # 1
        (6, 9, 22),  # 3
        (1, 10, 55),  # 4
        (3, 12, 55),  # 6
        (1, 10, 88),  # 7
        (3, 12, 88),  # 9
        (4, 7, 111),  # 10
        (6, 9, 111),  # 12
        (1, 3, 55, 88),  # 2 -> 22
        (4, 6, 22, 111),  # 5 -> 55
        (7, 9, 22, 111),  # 8 -> 88
        (10, 12, 55, 88),  # 11 -> 111
    )
    renumbered_found = sm.node_node_connectivity(renumbered_hexes)
    assert renumbered_expected == renumbered_found

    # case 3: the L-bracket example
    bracket_hexes = (
        (1, 2, 7, 6, 22, 23, 28, 27),
        (2, 3, 8, 7, 23, 24, 29, 28),
        (3, 4, 9, 8, 24, 25, 30, 29),
        (4, 5, 10, 9, 25, 26, 31, 30),
        (6, 7, 12, 11, 27, 28, 33, 32),
        (7, 8, 13, 12, 28, 29, 34, 33),
        (8, 9, 14, 13, 29, 30, 35, 34),
        (9, 10, 15, 14, 30, 31, 36, 35),
        (11, 12, 17, 16, 32, 33, 38, 37),
        (12, 13, 18, 17, 33, 34, 39, 38),
        (16, 17, 20, 19, 37, 38, 41, 40),
        (17, 18, 21, 20, 38, 39, 42, 41),
    )
    bracket_expected = (
        # bottom layer
        (2, 6, 22),
        (1, 3, 7, 23),
        (2, 4, 8, 24),
        (3, 5, 9, 25),
        (4, 10, 26),
        #
        (1, 7, 11, 27),
        (2, 6, 8, 12, 28),
        (3, 7, 9, 13, 29),
        (4, 8, 10, 14, 30),
        (5, 9, 15, 31),
        #
        (6, 12, 16, 32),
        (7, 11, 13, 17, 33),
        (8, 12, 14, 18, 34),
        (9, 13, 15, 35),
        (10, 14, 36),
        #
        (11, 17, 19, 37),
        (12, 16, 18, 20, 38),
        (13, 17, 21, 39),
        #
        (16, 20, 40),
        (17, 19, 21, 41),
        (18, 20, 42),
        # top layer
        (1, 23, 27),
        (2, 22, 24, 28),
        (3, 23, 25, 29),
        (4, 24, 26, 30),
        (5, 25, 31),
        #
        (6, 22, 28, 32),
        (7, 23, 27, 29, 33),
        (8, 24, 28, 30, 34),
        (9, 25, 29, 31, 35),
        (10, 26, 30, 36),
        #
        (11, 27, 33, 37),
        (12, 28, 32, 34, 38),
        (13, 29, 33, 35, 39),
        (14, 30, 34, 36),
        (15, 31, 35),
        #
        (16, 32, 38, 40),
        (17, 33, 37, 39, 41),
        (18, 34, 38, 42),
        #
        (19, 37, 41),
        (20, 38, 40, 42),
        (21, 39, 41),
    )
    bracket_found = sm.node_node_connectivity(bracket_hexes)
    assert bracket_expected == bracket_found

smoothing.py

r"""This module, smoothing.py, contains the core computations for
smoothing algorithms.
"""

# import sandbox.smoothing_types as ty
import smoothing_types as ty


# Type aliases for functional style methods, so the functions below can use
# the smoothing types without the `ty.` prefix.
# https://docs.python.org/3/library/typing.html#type-aliases
Hexes = ty.Hexes
Hierarchy = ty.Hierarchy
Neighbors = ty.Neighbors
NodeHierarchy = ty.NodeHierarchy
PrescribedNodes = ty.PrescribedNodes
Vertex = ty.Vertex
Vertices = ty.Vertices
SmoothingAlgorithm = ty.SmoothingAlgorithm


def average_position(vertices: Vertices) -> Vertex:
    """Return the centroid (arithmetic mean position) of some vertices.

    Each coordinate of the result is the sum of that coordinate over all
    input vertices divided by the number of vertices.

    Parameters:
    vertices (Vertices): A sized collection of Vertex objects, each with
                         x, y, and z attributes.

    Returns:
    Vertex: A new Vertex whose x, y, and z are the averages of the
            corresponding input coordinates.

    Raises:
    AssertionError: If `vertices` is empty.

    Example:
    >>> v1 = Vertex(1, 2, 3)
    >>> v2 = Vertex(4, 5, 6)
    >>> average_position([v1, v2])
    Vertex(x=2.5, y=3.5, z=4.5)
    """
    count = len(vertices)
    assert count > 0, "Error: number of vertices must be positive."

    # accumulate each coordinate separately, in input order
    total_x = sum(point.x for point in vertices)
    total_y = sum(point.y for point in vertices)
    total_z = sum(point.z for point in vertices)

    return Vertex(x=total_x / count, y=total_y / count, z=total_z / count)


def add(v1: Vertex, v2: Vertex) -> Vertex:
    """
    Return the component-wise sum of two vertices.

    Neither input is modified; a new Vertex is created.

    Parameters:
    v1 (Vertex): The left-hand operand.
    v2 (Vertex): The right-hand operand.

    Returns:
    Vertex: The vertex (v1.x + v2.x, v1.y + v2.y, v1.z + v2.z).

    Example:
    >>> v1 = Vertex(1, 2, 3)
    >>> v2 = Vertex(4, 5, 6)
    >>> add(v1, v2)
    Vertex(x=5, y=7, z=9)
    """
    return Vertex(
        x=v1.x + v2.x,
        y=v1.y + v2.y,
        z=v1.z + v2.z,
    )


def subtract(v1: Vertex, v2: Vertex) -> Vertex:
    """
    Return the component-wise difference (v1 - v2) of two vertices.

    Neither input is modified; a new Vertex is created.

    Parameters:
    v1 (Vertex): The vertex to subtract from.
    v2 (Vertex): The vertex that is subtracted.

    Returns:
    Vertex: The vertex (v1.x - v2.x, v1.y - v2.y, v1.z - v2.z).

    Example:
    >>> v1 = Vertex(8, 5, 2)
    >>> v2 = Vertex(1, 2, 3)
    >>> subtract(v1, v2)
    Vertex(x=7, y=3, z=-1)
    """
    return Vertex(
        x=v1.x - v2.x,
        y=v1.y - v2.y,
        z=v1.z - v2.z,
    )


def scale(vertex: Vertex, scale_factor: float) -> Vertex:
    """
    Return a vertex with every coordinate multiplied by a scale factor.

    The input vertex is not modified; a new Vertex is created.

    Parameters:
    vertex (Vertex): The vertex to scale.
    scale_factor (float): The multiplier applied to x, y, and z. Any
                          real number is accepted, including negative
                          values and zero.

    Returns:
    Vertex: The vertex (scale_factor * x, scale_factor * y,
            scale_factor * z).

    Example:
    >>> v = Vertex(1, 2, 3)
    >>> scale_factor = 2
    >>> scale(v, scale_factor)
    Vertex(x=2, y=4, z=6)
    """
    return Vertex(
        x=scale_factor * vertex.x,
        y=scale_factor * vertex.y,
        z=scale_factor * vertex.z,
    )


def xyz(v1: Vertex) -> tuple[float, float, float]:
    """
    Return the coordinates of a vertex as a plain (x, y, z) tuple.

    Parameters:
    v1 (Vertex): The vertex whose coordinates are read.

    Returns:
    tuple[float, float, float]: The x, y, and z coordinates, in that
                                 order.

    Example:
    >>> v = Vertex(1, 2, 3)
    >>> xyz(v)
    (1, 2, 3)
    """
    return (v1.x, v1.y, v1.z)


def smoothing_neighbors(neighbors: Neighbors, node_hierarchy: NodeHierarchy):
    """
    Select, for every node, the neighbors that take part in smoothing.

    The mesh connectivity gives each node a set of neighbors. The
    hierarchy level of a node decides which of those neighbors are kept:

    * INTERIOR: all neighbors are kept.
    * BOUNDARY: only neighbors whose hierarchy value is greater than or
      equal to the BOUNDARY value are kept.
    * PRESCRIBED: no neighbors are kept.

    Entries are looked up by position: entry `i` of both inputs belongs
    to node number `i + 1`, and neighbor numbers are 1-based.

    Parameters:
    neighbors (Neighbors): The original neighbors of each node.
    node_hierarchy (NodeHierarchy): The hierarchy level of each node,
                                    one of INTERIOR, BOUNDARY, or
                                    PRESCRIBED.

    Returns:
    tuple: One tuple of neighbor node numbers per node, each a subset
           of the original neighbors of that node.

    Raises:
    ValueError: If a hierarchy level is not one of INTERIOR, BOUNDARY,
                or PRESCRIBED (values 0, 1, or 2, respectively).

    Example:
    INTERIOR     PRESCRIBED      INTERIOR
       (1) -------- (3) ----------- (5)
        |
       (2) -------- (4) ----------- (6)
    BOUNDARY     BOUNDARY        INTERIOR

    >>> neighbors = ((2, 3), (1, 4), (1, 5), (2, 6), (3,), (4,))
    >>> node_hierarchy = (Hierarchy.INTERIOR, Hierarchy.BOUNDARY,
    ...                   Hierarchy.PRESCRIBED, Hierarchy.BOUNDARY,
    ...                   Hierarchy.INTERIOR, Hierarchy.INTERIOR)
    >>> smoothing_neighbors(neighbors, node_hierarchy)
    ((2, 3), (4,), (), (2,), (3,), (4,))
    """
    selected_per_node = []

    for position, rank in enumerate(node_hierarchy):
        candidates = neighbors[position]
        # neighbor numbers are 1-based, the hierarchy tuple is 0-based
        candidate_ranks = [
            int(node_hierarchy[number - 1].value) for number in candidates
        ]

        if rank == Hierarchy.INTERIOR:
            selected = candidates
        elif rank == Hierarchy.BOUNDARY:
            selected = tuple(
                number
                for number, candidate_rank in zip(candidates, candidate_ranks)
                if candidate_rank >= rank.value
            )
        elif rank == Hierarchy.PRESCRIBED:
            selected = ()
        else:
            raise ValueError("Hierarchy value must be in [0, 1, 2]")

        selected_per_node.append(selected)

    return tuple(selected_per_node)


def smooth(
    vv: Vertices,
    hexes: Hexes,
    node_hierarchy: NodeHierarchy,
    prescribed_nodes: PrescribedNodes,
    scale_lambda: float,
    num_iters: int,
    algorithm: SmoothingAlgorithm,
) -> Vertices:
    """
    Smooth the vertices of a hexahedral mesh and return the updated
    coordinates.

    The node-to-node neighbors are derived from `hexes`. In every
    iteration, each vertex that has smoothing neighbors is moved by
    `scale_lambda` times the vector from the vertex to the average
    position of its neighbors. All updates within one iteration are
    computed from the positions at the start of that iteration. A vertex
    without smoothing neighbors keeps its position.

    If `node_hierarchy` contains at least one PRESCRIBED node, then the
    neighbors are first filtered with `smoothing_neighbors`, and the
    positions in `prescribed_nodes` overwrite the initial positions of
    the listed nodes before the first iteration.

    NOTE(review): the hierarchy-based neighbor filtering is only applied
    when a PRESCRIBED node is present; a hierarchy that mixes only
    INTERIOR and BOUNDARY nodes is smoothed with unfiltered neighbors.
    Confirm that this is intended.

    NOTE(review): `algorithm` is only reported with `print`; the update
    applied here is the same for every SmoothingAlgorithm member.

    Parameters:
    vv (Vertices): The initial vertex positions; entry `i` belongs to
                   node number `i + 1`.
    hexes (Hexes): The hex element connectivity, with 1-based node
                   numbers.
    node_hierarchy (NodeHierarchy): The hierarchy level of each node.
    prescribed_nodes (PrescribedNodes): Pairs of (node number, position),
                                        or None. Required when
                                        `node_hierarchy` contains a
                                        PRESCRIBED node.
    scale_lambda (float): The scale factor applied to each update.
    num_iters (int): The number of smoothing iterations, 1 or greater.
    algorithm (SmoothingAlgorithm): The smoothing algorithm to report.

    Returns:
    Vertices: The vertex positions after `num_iters` iterations.

    Raises:
    AssertionError: If `num_iters` is less than 1; if PRESCRIBED nodes
                    exist but `prescribed_nodes` is None; or if the
                    number of PRESCRIBED nodes does not match the number
                    of entries in `prescribed_nodes`.
    """
    print(f"Smoothing algorithm: {algorithm.value}")

    # Input checks raise AssertionError explicitly (rather than through
    # `assert`) so they are not stripped when Python runs with -O, while
    # callers still see the same exception type.
    if num_iters < 1:
        raise AssertionError("`num_iters` must be 1 or greater")

    nn = node_node_connectivity(hexes)

    # if the node_hierarchy contains a Hierarchy.PRESCRIBED type, then
    # the PrescribedNodes must not be None
    if Hierarchy.PRESCRIBED in node_hierarchy:
        info = "Smoothing algorithm with hierarchical control"
        info += " and PRESCRIBED node positions."
        print(info)
        estr = "Error, NodeHierarchy desigates PRESCRIBED nodes, but no values"
        estr += " for (x, y, z) prescribed positions were given."
        if prescribed_nodes is None:
            raise AssertionError(estr)

        n_nodes_prescribed = node_hierarchy.count(Hierarchy.PRESCRIBED)
        n_prescribed_xyz = len(prescribed_nodes)
        estr = f"Error: number of PRESCRIBED nodes: {n_nodes_prescribed}"
        estr += " must match the number of"
        estr += f" prescribed Vertices(x, y, z): {n_prescribed_xyz}"
        if n_nodes_prescribed != n_prescribed_xyz:
            raise AssertionError(estr)

        # restrict the neighbors according to the node hierarchy
        nn = smoothing_neighbors(neighbors=nn, node_hierarchy=node_hierarchy)

        # move the prescribed nodes to their prescribed positions
        vv_list = list(vv)  # make mutable
        for node_id, node_xyz in prescribed_nodes:
            vv_list[node_id - 1] = node_xyz  # node numbers are 1-based
        vv = tuple(vv_list)  # revert to immutable

    vertices_old = vv

    for k in range(num_iters):

        print(f"Iteration: {k+1}")
        vertices_new = []

        for vertex, neighbors in zip(vertices_old, nn):
            # node numbers are 1-based, the vertex sequence is 0-based
            neighbor_vertices = tuple(vertices_old[i - 1] for i in neighbors)

            if neighbor_vertices:
                neighbor_average = average_position(neighbor_vertices)
                delta = subtract(v1=neighbor_average, v2=vertex)
                lambda_delta = scale(vertex=delta, scale_factor=scale_lambda)
                vertex_new = add(v1=vertex, v2=lambda_delta)
            else:
                # no smoothing neighbors (e.g., a PRESCRIBED node): the
                # vertex keeps its position
                vertex_new = vertex

            vertices_new.append(vertex_new)

        vertices_old = vertices_new  # input for the next iteration

    # vertices_old holds the result of the last iteration
    return tuple(vertices_old)


def pair_ordered(ab: tuple[tuple[int, int], ...]) -> tuple:
    """
    Order pairs of integers based on their values.

    Given a tuple of pairs in the form ((a, b), (c, d), ...), this
    function orders each pair such that the smaller integer comes
    first. It then sorts the resulting pairs primarily by the first
    element and secondarily by the second element.

    Parameters:
    ab (tuple[tuple[int, int], ...]): A tuple containing pairs of
                                      integers. It may be empty.

    Returns:
    tuple: A new tuple containing the ordered pairs, where each pair
           is sorted internally and the entire collection is sorted
           based on the first and second elements. An empty input
           gives an empty tuple.

    Raises:
    ValueError: If an item of `ab` does not unpack into exactly two
                values.

    Example:
    >>> pairs = ((3, 1), (2, 4), (5, 0))
    >>> pair_ordered(pairs)
    ((0, 5), (1, 3), (2, 4))
    >>> pair_ordered(())
    ()
    """
    # Iterate over the pairs directly: unzipping with zip(*ab) first, as
    # was done previously, fails to unpack when `ab` is empty.
    return tuple(sorted((a, b) if a < b else (b, a) for a, b in ab))


def edge_pairs(hexes: Hexes):
    """
    Collect the unique edges of a set of hex elements.

    Every hex element contributes twelve edges: the four edges around
    its first four nodes (bottom face), the four edges around its last
    four nodes (top face), and the four edges that join node `k` to node
    `k + 4` (verticals). Each edge is stored with its smaller node number
    first, so an edge shared by several elements appears only once.

    Used for drawing edges of finite elements.

    Parameters:
    hexes (Hexes): A collection of hex elements, where each hex is
                   represented by a tuple of node numbers.

    Returns:
    tuple: A sorted tuple of unique edges, where each edge is a tuple
           of two node numbers with the smaller number first.
    """
    # positions, within one hex, of the two end nodes of each edge
    edge_slots = (
        # bottom face
        (0, 1),
        (1, 2),
        (2, 3),
        (3, 0),
        # top face
        (4, 5),
        (5, 6),
        (6, 7),
        (7, 4),
        # verticals
        (0, 4),
        (1, 5),
        (2, 6),
        (3, 7),
    )

    unique_edges = set()
    for element in hexes:
        for start, stop in edge_slots:
            first, second = element[start], element[stop]
            # store the smaller node number first
            edge = (first, second) if first < second else (second, first)
            unique_edges.add(edge)

    return tuple(sorted(unique_edges))


def node_node_connectivity(hexes: Hexes) -> Neighbors:
    """
    Build the node-to-node adjacency implied by hexahedral elements.

    Two nodes are neighbors when they are joined by an element edge,
    as reported by edge_pairs.

    Parameters:
    hexes (Hexes): A collection of hexahedral elements, where each
                   element is represented by a tuple of node indices.

    Returns:
    Neighbors: A tuple with one entry per distinct node number found
               in hexes, ordered by ascending node number. Each entry
               is the tuple of node numbers that share an edge with
               that node.
    """
    # every node that appears in any element starts with no neighbors
    adjacency = {node: () for element in hexes for node in element}

    for low, high in edge_pairs(hexes):
        # Both right-hand sides are evaluated before either key is
        # written, so each end is extended from its pre-update value.
        adjacency[low], adjacency[high] = (
            adjacency[low] + (high,),
            adjacency[high] + (low,),
        )

    # emit the neighbor tuples in ascending node-number order
    return tuple(adjacency[node] for node in sorted(adjacency))

smoothing_examples.py

r"""This module, smoothing_examples.py contains data for the
smoothing examples.
"""

import math
from typing import Final

import smoothing_types as ty

# Short local names for the smoothing types used by the example data
# https://docs.python.org/3/library/typing.html#type-aliases
Example = ty.SmoothingExample
Hierarchy = ty.Hierarchy
SmoothingAlgorithm = ty.SmoothingAlgorithm
Vertex = ty.Vertex

# Multiply an angle in degrees by this factor to obtain radians
DEG2RAD: Final[float] = math.pi / 180.0  # rad/deg

# L-bracket example
# The mesh is one element thick: two identical layers of 21 nodes each
# (z = 0 and z = 1).  Node numbers are 1-based, and a node on the top
# layer is numbered 21 higher than the node directly beneath it, so each
# table below lists the bottom layer once and derives the top layer.
bracket = Example(
    vertices=tuple(
        Vertex(x, y, z)
        for z in (0, 1)
        for (x, y) in (
            # y = 0, nodes 1-5 (22-26 on the top layer)
            (0, 0),
            (1, 0),
            (2, 0),
            (3, 0),
            (4, 0),
            # y = 1, nodes 6-10 (27-31)
            (0, 1),
            (1, 1),
            (2, 1),
            (3, 1),
            (4, 1),
            # y = 2, nodes 11-15 (32-36)
            (0, 2),
            (1, 2),
            (2, 2),
            (3, 2),
            (4, 2),
            # y = 3, nodes 16-18 (37-39)
            (0, 3),
            (1, 3),
            (2, 3),
            # y = 4, nodes 19-21 (40-42)
            (0, 4),
            (1, 4),
            (2, 4),
        )
    ),
    elements=tuple(
        # bottom-face nodes followed by the matching top-face nodes
        (n1, n2, n3, n4, n1 + 21, n2 + 21, n3 + 21, n4 + 21)
        for (n1, n2, n3, n4) in (
            (1, 2, 7, 6),
            (2, 3, 8, 7),
            (3, 4, 9, 8),
            (4, 5, 10, 9),
            (6, 7, 12, 11),
            (7, 8, 13, 12),
            (8, 9, 14, 13),
            (9, 10, 15, 14),
            (11, 12, 17, 16),
            (12, 13, 18, 17),
            (16, 17, 20, 19),
            (17, 18, 21, 20),
        )
    ),
    nelx=4,
    nely=4,
    nelz=1,
    # Reference neighbor lists, one per node, kept as comments only.
    # neighbors=(
    #     (2, 6, 22),
    #     (1, 3, 7, 23),
    #     (2, 4, 8, 24),
    #     (3, 5, 9, 25),
    #     (4, 10, 26),
    #     #
    #     (1, 7, 11, 27),
    #     (2, 6, 8, 12, 28),
    #     (3, 7, 9, 13, 29),
    #     (4, 8, 10, 14, 30),
    #     (5, 9, 15, 31),
    #     #
    #     (6, 12, 16, 32),
    #     (7, 11, 13, 17, 33),
    #     (8, 12, 14, 18, 34),
    #     (9, 13, 15, 35),
    #     (10, 14, 36),
    #     #
    #     (11, 17, 19, 37),
    #     (12, 16, 18, 20, 38),
    #     (13, 17, 21, 39),
    #     #
    #     (16, 20, 40),
    #     (17, 19, 21, 41),
    #     (18, 20, 42),
    #     # top layer
    #     (1, 23, 27),
    #     (2, 22, 24, 28),
    #     (3, 23, 25, 29),
    #     (4, 24, 26, 30),
    #     (5, 25, 31),
    #     #
    #     (6, 22, 28, 32),
    #     (7, 23, 27, 29, 33),
    #     (8, 24, 28, 30, 34),
    #     (9, 25, 29, 31, 35),
    #     (10, 26, 30, 36),
    #     #
    #     (11, 27, 33, 37),
    #     (12, 28, 32, 34, 38),
    #     (13, 29, 33, 35, 39),
    #     (14, 30, 34, 36),
    #     (15, 31, 35),
    #     #
    #     (16, 32, 38, 40),
    #     (17, 33, 37, 39, 41),
    #     (18, 34, 38, 42),
    #     #
    #     (19, 37, 41),
    #     (20, 38, 40, 42),
    #     (21, 39, 41),
    # ),
    node_hierarchy=(
        # One entry per bottom-layer node; the trailing "* 2" repeats the
        # same pattern for the top layer.  Comments give the bottom/top
        # node numbers and, for prescribed nodes, the target (x, y).
        Hierarchy.PRESCRIBED,  # 1, 22 -> (0, 0)
        Hierarchy.PRESCRIBED,  # 2, 23 -> (1, 0)
        Hierarchy.PRESCRIBED,  # 3, 24 -> (2, 0)
        Hierarchy.PRESCRIBED,  # 4, 25 -> (3, 0)
        Hierarchy.PRESCRIBED,  # 5, 26 -> (4, 0)
        Hierarchy.PRESCRIBED,  # 6, 27 -> (0, 1)
        Hierarchy.BOUNDARY,  # 7, 28
        Hierarchy.BOUNDARY,  # 8, 29
        Hierarchy.BOUNDARY,  # 9, 30
        Hierarchy.PRESCRIBED,  # 10, 31 -> 4.5 * (cos, sin)(15 deg)
        Hierarchy.PRESCRIBED,  # 11, 32 -> (0, 2)
        Hierarchy.BOUNDARY,  # 12, 33
        Hierarchy.BOUNDARY,  # 13, 34
        Hierarchy.BOUNDARY,  # 14, 35
        Hierarchy.PRESCRIBED,  # 15, 36 -> 4.5 * (cos, sin)(30 deg)
        Hierarchy.PRESCRIBED,  # 16, 37 -> (0, 3)
        Hierarchy.BOUNDARY,  # 17, 38
        Hierarchy.BOUNDARY,  # 18, 39
        Hierarchy.PRESCRIBED,  # 19, 40 -> (0, 4)
        Hierarchy.PRESCRIBED,  # 20, 41 -> (1.5, 4)
        Hierarchy.PRESCRIBED,  # 21, 42 -> (3.5, 4)
    )
    * 2,
    prescribed_nodes=tuple(
        # (node number, target position); z is 0 for the bottom layer
        # and 1 for the top layer
        (node + 21 * z, Vertex(x, y, z))
        for z in (0, 1)
        for (node, x, y) in (
            (1, 0, 0),
            (2, 1, 0),
            (3, 2, 0),
            (4, 3, 0),
            (5, 4, 0),
            (6, 0, 1),
            (
                10,
                4.5 * math.cos(15 * DEG2RAD),
                4.5 * math.sin(15 * DEG2RAD),
            ),
            (11, 0, 2),
            (
                15,
                4.5 * math.cos(30 * DEG2RAD),
                4.5 * math.sin(30 * DEG2RAD),
            ),
            (16, 0, 3),
            (19, 0, 4),
            (20, 1.5, 4),
            (21, 3.5, 4),
        )
    ),
    scale_lambda=0.3,
    scale_mu=-0.33,
    num_iters=10,
    algorithm=SmoothingAlgorithm.LAPLACE,
    file_stem="bracket",
)

# Double X two-element example
# Two unit hexes side by side along x: a 3 x 2 x 2 grid of nodes,
# numbered from 1 with x varying fastest, then y, then z.
double_x = Example(
    vertices=tuple(
        Vertex(x, y, z)
        for z in (0.0, 1.0)
        for y in (0.0, 1.0)
        for x in (0.0, 1.0, 2.0)
    ),
    elements=(
        (1, 2, 5, 4, 7, 8, 11, 10),
        (2, 3, 6, 5, 8, 9, 12, 11),
    ),
    nelx=2,
    nely=1,
    nelz=1,
    # Reference neighbor lists, one per node, kept as comments only.
    # neighbors=(
    #     (2, 4, 7),
    #     (1, 3, 5, 8),
    #     (2, 6, 9),
    #     (1, 5, 10),
    #     (2, 4, 6, 11),
    #     (3, 5, 12),
    #     (1, 8, 10),
    #     (2, 7, 9, 11),
    #     (3, 8, 12),
    #     (4, 7, 11),
    #     (5, 8, 10, 12),
    #     (6, 9, 11),
    # ),
    # all twelve nodes are boundary nodes; none is prescribed
    node_hierarchy=(Hierarchy.BOUNDARY,) * 12,
    prescribed_nodes=None,
    scale_lambda=0.3,
    scale_mu=-0.33,
    num_iters=2,
    algorithm=SmoothingAlgorithm.LAPLACE,
    file_stem="double_x",
)

smoothing_figures.py

r"""This module, smoothing_figures.py, illustrates test cases for
smoothing algorithms.

Example
-------
source ~/autotwin/automesh/.venv/bin/activate
cd ~/autotwin/automesh/book/smoothing
python smoothing_figures.py
"""

import datetime
from pathlib import Path
from typing import Final

from matplotlib.colors import LightSource
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

import smoothing as sm
import smoothing_examples as se
import smoothing_types as ty

# Local aliases for the smoothing types
# https://docs.python.org/3/library/typing.html#type-aliases
Hexes = ty.Hexes
Neighbors = ty.Neighbors
NodeHierarchy = ty.NodeHierarchy
SmoothingAlgorithm = ty.SmoothingAlgorithm
Vertex = ty.Vertex
Vertices = ty.Vertices

# Example to plot; se.double_x is the smaller alternative
ex = se.bracket

# Output controls
DPI: Final[int] = 300  # resolution, dots per inch
SHOW: Final[bool] = True  # Shows the figure on screen
SAVE: Final[bool] = True  # Saves the figure to a .png file

# Colors, transparency, and lighting
NUM_COLORS = 10
VOXEL_ALPHA: Final[float] = 0.9
LINE_ALPHA: Final[float] = 0.5
cmap = plt.get_cmap(name="tab10")
colors = cmap(np.linspace(0, 1, NUM_COLORS))
lightsource = LightSource(azdeg=325, altdeg=45)  # azimuth, elevation

# Camera angles, in the order passed to view_init
el = 63
az = -110
roll = 0

# One figure holding two side-by-side 3D axes
width = 10
height = 5
fig = plt.figure(figsize=(width, height))
ax = fig.add_subplot(1, 2, 1, projection="3d")  # left subplot
ax2 = fig.add_subplot(1, 2, 2, projection="3d")  # right subplot

# Number of elements along each axis, and one more than that count
nx = ex.nelx
ny = ex.nely
nz = ex.nelz
nxp = nx + 1
nyp = ny + 1
nzp = nz + 1

# Smooth the example mesh; the result is drawn next to the original
vertices_laplace = sm.smooth(
    vv=ex.vertices,
    hexes=ex.elements,
    node_hierarchy=ex.node_hierarchy,
    prescribed_nodes=ex.prescribed_nodes,
    scale_lambda=ex.scale_lambda,
    num_iters=ex.num_iters,
    algorithm=ex.algorithm,
)

# Coordinates of the original vertices, one list per axis
xs, ys, zs = (
    [vertex.x for vertex in ex.vertices],
    [vertex.y for vertex in ex.vertices],
    [vertex.z for vertex in ex.vertices],
)

# Coordinates of the smoothed vertices, one list per axis
xs_l, ys_l, zs_l = (
    [vertex.x for vertex in vertices_laplace],
    [vertex.y for vertex in vertices_laplace],
    [vertex.z for vertex in vertices_laplace],
)

# Element edges as pairs of 1-based node numbers
ep = sm.edge_pairs(ex.elements)


def _edge_segments(points):
    """Return one (start, end) pair of sm.xyz points per edge in ep."""
    segments = []
    for first, second in ep:
        # node numbers are 1-based; sequence indices are 0-based
        start = sm.xyz(points[first - 1])
        end = sm.xyz(points[second - 1])
        segments.append((start, end))
    return segments


line_segments = _edge_segments(ex.vertices)
line_segments_laplace = _edge_segments(vertices_laplace)
def _draw_segments(axis, segments, **line_style):
    """Plot every (start, end) segment on axis with the given style."""
    for start, end in segments:
        axis.plot3D(
            [start[0], end[0]],
            [start[1], end[1]],
            [start[2], end[2]],
            **line_style,
        )


# Left axis: the original mesh, edges first and then nodes
_draw_segments(
    ax, line_segments, linestyle="solid", linewidth=0.5, color="blue"
)
ax.scatter(xs, ys, zs, s=20, facecolors="blue", edgecolors="none")

# Right axis: the original mesh again, faded and dashed ...
_draw_segments(
    ax2,
    line_segments,
    linestyle="dashed",
    linewidth=0.5,
    color="blue",
    alpha=LINE_ALPHA,
)
# ... overlaid with the smoothed mesh in solid red
_draw_segments(
    ax2, line_segments_laplace, linestyle="solid", linewidth=0.5, color="red"
)
ax2.scatter(
    xs, ys, zs, s=20, facecolors="blue", edgecolors="none", alpha=0.5
)
ax2.scatter(xs_l, ys_l, zs_l, s=20, facecolors="red", edgecolors="none")

# Integer tick positions, one per node plane along each axis
x_ticks = list(range(nxp))
y_ticks = list(range(nyp))
z_ticks = list(range(nzp))

# Apply identical labels, ticks, limits, aspect, and camera to both axes
for axis in (ax, ax2):
    axis.set_xlabel("x")
    axis.set_ylabel("y")
    axis.set_zlabel("z")

    axis.set_xticks(x_ticks)
    axis.set_yticks(y_ticks)
    axis.set_zticks(z_ticks)

    # limits span from the first tick to the last tick
    axis.set_xlim(float(x_ticks[0]), float(x_ticks[-1]))
    axis.set_ylim(float(y_ticks[0]), float(y_ticks[-1]))
    axis.set_zlim(float(z_ticks[0]), float(z_ticks[-1]))

    # camera view
    axis.set_aspect("equal")
    axis.view_init(elev=el, azim=az, roll=roll)

# Output file: <file_stem>_iter_<num_iters>.png next to this script
aa = Path(__file__)
fig_path = aa.parent
fig_stem = ex.file_stem
FIG_EXT: Final[str] = ".png"
bb = fig_path.joinpath(fig_stem + "_iter_" + str(ex.num_iters) + FIG_EXT)

# Footnote naming the figure, the generating script, and the UTC
# creation time
now_utc = datetime.datetime.now(datetime.UTC)
timestamp_utc = now_utc.strftime("%Y-%m-%d %H:%M:%S UTC")
fn = f"Figure: {bb.name} "
fn += f"created with {__file__}\non {timestamp_utc}."
fig.text(0.5, 0.01, fn, ha="center", fontsize=8)

# fig.tight_layout()  # don't use as it clips the x-axis label
if SHOW:
    plt.show()

# SAVE is honored on its own: previously this branch was nested under
# SHOW, so nothing was written when SHOW was False.  When both flags
# are set the order is unchanged: show first, then save.
if SAVE:
    fig.savefig(bb, dpi=DPI)
    print(f"Saved: {bb}")

print("End of script.")