Pyvista for dolfinx

This package reintroduces simple plot and show functionality to FEniCSx. It can plot the most common data structures during both serial and parallel execution.

FEniCSx’s design principles prioritize parallel efficiency and, in an effort to avoid opaque performance pitfalls, favor fine-grained control over a high-level interface. Consequently, simple, general-purpose plotting routines are no longer part of the standard library. Instead, users are referred to pyvista for their plotting needs. Unfortunately, quite a bit of boilerplate code is required to interface dolfinx and pyvista, whereas at a prototyping stage one often desires quick-and-dirty visualization with a simple interface. To facilitate this, pyvista4dolfinx provides a single plot function that can be used to plot most of dolfinx’s visualizable data structures: scalar- and vector-valued Functions, Meshes, MeshTags (e.g. element and facet markers), and even integration Measures. The function returns a pyvista.Plotter instance, so that the user retains access to pyvista’s full range of capabilities.

While this package supports plotting during parallel execution, it should be understood that this comes at a performance penalty: data is gathered and manipulated on a single core. For data-visualization tasks that go beyond basic prototyping, users are encouraged to make use of the VTK export capabilities natively supported in dolfinx.

Demo

The following demo script illustrates the ease of use. The pyvista4dolfinx-specific lines are those that import or call plot, show and Plotter.

 1import os
 2import sys
 3
 4sys.path.insert(0, os.path.abspath("../"))  # Access to pyvista4dolfinx
 5from pyvista4dolfinx import plot, show, Plotter
 6
 7from dolfinx import mesh, fem, io, default_scalar_type
 8from dolfinx.fem.petsc import LinearProblem
 9import matplotlib.pyplot as plt
10from mpi4py import MPI
11import numpy as np
12import ufl
13
14# Mesh and function space
15domain = mesh.create_unit_square(MPI.COMM_WORLD, 20, 20, cell_type=mesh.CellType.quadrilateral)
16V = fem.functionspace(domain, ("Lagrange", 1, (domain.geometry.dim,)))
17
18# Quick visualization of mesh
19plot(domain, show_partitioning=True, show=True)
20
21# Exogeneous data
22u_D = np.array([0, 0], dtype=default_scalar_type)
23T = fem.Constant(domain, default_scalar_type((-0.1, 0)))
24f = fem.Constant(domain, default_scalar_type((0, -5)))
25
26# Boundary tags
27facets_horizontal = mesh.locate_entities_boundary(domain, 
28                                              domain.topology.dim - 1, 
29                                              lambda x: np.isclose(x[0], 0) | np.isclose(x[0], 1))
30facets_bottom = mesh.locate_entities_boundary(domain, 
31                                              domain.topology.dim - 1, 
32                                              lambda x: np.isclose(x[1], 0))
33facets_top = mesh.locate_entities_boundary(domain, 
34                                           domain.topology.dim - 1, 
35                                           lambda x: np.isclose(x[1], 1))
36marked_facets = np.hstack([facets_bottom, facets_top, facets_horizontal])
37marked_values = np.hstack([np.full_like(facets_bottom, 1), np.full_like(facets_top, 2), np.full_like(facets_horizontal, 99)])
38sorting_map = np.argsort(marked_facets)
39facet_tags = mesh.meshtags(domain,domain.topology.dim-1,
40                           marked_facets[sorting_map],
41                           marked_values[sorting_map])
42
43# Volume tags
44def volume_inclusion(x):
45    return (x[0] <= 0.75) & (x[0] >= 0.25) & (x[1] <= 0.75) & (x[1] >= 0.25)
46elements_inclusion = np.sort(mesh.locate_entities(domain,domain.topology.dim,volume_inclusion))
47element_tags = mesh.meshtags(domain,domain.topology.dim,elements_inclusion,3)
48
49# Quick visualization of tags
50plot(domain)
51plot(element_tags, domain, name="Element tags", cmap="cool", \
52                    scalar_bar_args={"vertical": True,"position_x": 0.85,"position_y": 0.25})
53plot(facet_tags, domain, name="Facet tags", show=True, 
54                    scalar_bar_args={"position_x": 0.21,"position_y": 0.05})
55
56# Measures from tags
57ds = ufl.Measure("ds", domain=domain, subdomain_data=facet_tags)
58dx = ufl.Measure("dx", domain=domain, subdomain_data=element_tags)
59
60# Quick visualization of integration measure
61plot(dx(3), domain, name="Volume measure", cmap="cool", \
62                    scalar_bar_args={"vertical": True,"position_x": 0.85,"position_y": 0.25})
63plot(ds(2), domain, name="Surface measure", show=True, 
64                    scalar_bar_args={"position_x": 0.21,"position_y": 0.05})
65
66# Clamping condition on the bottom
67Vdofs_bottom = fem.locate_dofs_topological(V, domain.topology.dim - 1, facets_bottom)
68bc = fem.dirichletbc(u_D, Vdofs_bottom, V)
69
70# Linear elasticity formulation
71def epsilon(u):
72    return ufl.sym(ufl.grad(u))
73def sigma(u):
74    return ufl.nabla_div(u) * ufl.Identity(len(u)) + 2 * epsilon(u)
75u = ufl.TrialFunction(V)
76v = ufl.TestFunction(V)
77a = ufl.inner(sigma(u), epsilon(v)) * dx
78L = ufl.dot(f, v) * dx(3) + ufl.dot(T, v) * ds(2)
79
80# Solve problem
81problem = LinearProblem(
82    a, L, bcs=[bc], petsc_options={"ksp_type": "preonly", "pc_type": "lu"}
83)
84uh = problem.solve()
85uh.name = "Displacement"
86
87# Post process; von Mises stress
88s = sigma(uh) - 1. / 3 * ufl.tr(sigma(uh)) * ufl.Identity(len(uh))
89von_Mises = ufl.sqrt(3. / 2 * ufl.inner(s, s))
90V_von_mises = fem.functionspace(domain, ("DG", 0))
91stress_expr = fem.Expression(von_Mises, V_von_mises.element.interpolation_points())
92stresses = fem.Function(V_von_mises, name='Stress')
93stresses.interpolate(stress_expr)
94
95# Final visualization
96plotter = Plotter(shape=(1, 2))
97plot(uh, plotter=plotter, show_mesh=False, subplot=(0,0), factor=0.5)
98plot(stresses, warp=uh, show_mesh=False, subplot=(0,1))
99show()

When executed on three cores, each of the plot calls below produces the graphics shown after it.

# Quick visualization of mesh (show_partitioning=True)
plot(domain, show_partitioning=True, show=True)
plot_mesh
# Quick visualization of element and facet tags
plot(domain)
plot(element_tags, domain, name="Element tags", cmap="cool",
     scalar_bar_args={"vertical": True, "position_x": 0.85, "position_y": 0.25})
plot(facet_tags, domain, name="Facet tags", show=True,
     scalar_bar_args={"position_x": 0.21, "position_y": 0.05})
plot_meshtags
# Quick visualization of the integration measures dx(3) and ds(2)
plot(dx(3), domain, name="Volume measure", cmap="cool",
     scalar_bar_args={"vertical": True, "position_x": 0.85, "position_y": 0.25})
plot(ds(2), domain, name="Surface measure", show=True,
     scalar_bar_args={"position_x": 0.21, "position_y": 0.05})
plot_measure
# Final visualization: one Plotter with two subplots (shape=(1, 2));
# uh goes to subplot (0,0), stresses (with warp=uh) to subplot (0,1)
plotter = Plotter(shape=(1, 2))
plot(uh, plotter=plotter, show_mesh=False, subplot=(0,0), factor=0.5)
plot(stresses, warp=uh, show_mesh=False, subplot=(0,1))
show()
plot_functions

Documentation