Hi JAX-FEM community,
I just installed JAX-FEM on my Mac following the installation instructions.
When I try to run the examples, all of them fail with the same RuntimeError.
For example, running python -m demos.wave.example produces the following output:
[11-12 13:50:24][DEBUG] jax_fem: Computing shape function values, gradients, etc.
[11-12 13:50:24][DEBUG] jax_fem: ele_type = TRI3, quad_points.shape = (num_quads, dim) = (3, 2)
[11-12 13:50:24][DEBUG] jax_fem: face_quad_points.shape = (num_faces, num_face_quads, dim) = (3, 2, 2)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/tong/Desktop/github_tryouts/jax-fem/demos/wave/example.py", line 184, in <module>
    main_fns()
  File "/Users/tong/Desktop/github_tryouts/jax-fem/demos/wave/example.py", line 166, in main_fns
    problem = wave(mesh, vec=1, dim=2, ele_type = ele_type, gauss_order=2, dirichlet_bc_info = dirichlet_bc_info)
  File "<string>", line 11, in __init__
  File "/Users/tong/Desktop/github_tryouts/jax-fem/jax_fem/problem.py", line 37, in __post_init__
    self.fes = [FiniteElement(mesh=self.mesh[I],
  File "/Users/tong/Desktop/github_tryouts/jax-fem/jax_fem/problem.py", line 37, in <listcomp>
    self.fes = [FiniteElement(mesh=self.mesh[I],
  File "<string>", line 10, in __init__
  File "/Users/tong/Desktop/github_tryouts/jax-fem/jax_fem/fe.py", line 79, in __post_init__
    self.node_inds_list, self.vec_inds_list, self.vals_list = self.Dirichlet_boundary_conditions(self.dirichlet_bc_info)
  File "/Users/tong/Desktop/github_tryouts/jax-fem/jax_fem/fe.py", line 221, in Dirichlet_boundary_conditions
    node_inds = onp.argwhere(jax.vmap(location_fn)(self.mesh.points, np.arange(self.num_total_nodes))).reshape(-1)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3582, in arange
    return lax.iota(dtype, start)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1321, in iota
    return broadcasted_iota(dtype, (size,), 0)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1331, in broadcasted_iota
    return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/core.py", line 921, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
  File "/Users/tong/anaconda3/envs/jax-fem-env/lib/python3.9/site-packages/jax/_src/array.py", line 1146, in _array_global_result_handler
    return xc.array_result_handler(
RuntimeError: std::bad_cast
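The last JAX-FEM frame in the traceback is the np.arange(...) call inside jax.vmap in fe.py, and the error itself is raised inside JAX. A small standalone script may help tell whether the problem lies in the JAX installation or in JAX-FEM. The following is only a sketch and was not part of the original run: the helper names (environment_report, left_boundary, select_nodes) and the four-point test mesh are made up for illustration, and the onp/np aliases are assumed to match the ones used in fe.py.

```python
from importlib.metadata import version
import platform

import numpy as onp
import jax
import jax.numpy as np  # assumed to match the alias used in fe.py


def environment_report():
    """Collect version details that are useful to attach to a bug report."""
    return {
        "python": platform.python_version(),
        "machine": platform.machine(),  # e.g. "arm64" or "x86_64"
        "jax": version("jax"),
        "jaxlib": version("jaxlib"),
        "backend": jax.default_backend(),
    }


def left_boundary(point, ind):
    # Hypothetical location function: marks nodes lying on x == 0.
    return np.isclose(point[0], 0.0)


def select_nodes(points):
    """Mirror the failing expression from the traceback, without JAX-FEM."""
    num_total_nodes = points.shape[0]
    mask = jax.vmap(left_boundary)(points, np.arange(num_total_nodes))
    return onp.argwhere(mask).reshape(-1)


if __name__ == "__main__":
    print(environment_report())
    # Illustrative unit-square corner points, not the mesh from the demo.
    points = onp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    print(select_nodes(points))
```

If this script also stops with RuntimeError: std::bad_cast, the problem is probably in the JAX/jaxlib installation rather than in JAX-FEM. On a working installation the second print should list nodes 0 and 2. Either way, the printed version dictionary is useful to include in the report.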