Skip to content

Commit

Permalink
wave3d: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Aug 3, 2023
1 parent e54cd7b commit 625e976
Showing 1 changed file with 26 additions and 33 deletions.
59 changes: 26 additions & 33 deletions fast/wave3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,45 +88,41 @@ def plot_3dfunc(u):
# First, position source centrally in all dimensions, then set depth
src.coordinates.data[0, :] = np.array(model.domain_size) * .5

# We can plot the time signature to see the wavelet
# src.show()

# Define the wavefield with the size of the model and the time dimension
u = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so)
# Another one to clone data
u2 = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so)
ub = TimeFunction(name="ub", grid=model.grid, time_order=to, space_order=so)


# We can now write the PDE
# pde = model.m * u.dt2 - u.laplace + model.damp * u.dt
# import pdb;pdb.set_trace()
pde = u.dt2 - u.laplace

# The PDE representation is as on paper
# pde

stencil = Eq(u.forward, solve(pde, u.forward))
stencil

# Finally we define the source injection and receiver read function to generate
# the corresponding code
# print(time_range)

print("Init norm:", np.linalg.norm(u.data[:]))
# print("Init norm:", np.linalg.norm(u.data[:]))
src_term = src.inject(field=u.forward, expr=src * dt**2 / model.m)
op0 = Operator([stencil] + src_term, subs=model.spacing_map, name='SourceDevitoOperator')
# Run with source and plot
op0.apply(time=time_range.num-1, dt=model.critical_dt)


if len(shape) == 3:
if args.plot:
plot_3dfunc(u)

devito_norm = norm(u)
print("Init linalg norm 0 (inlined) :", norm(u))
print("Init linalg norm 0 :", np.linalg.norm(u.data[0]))
print("Init linalg norm 1 :", np.linalg.norm(u.data[1]))
print("Init linalg norm 2 :", np.linalg.norm(u.data[2]))
print("Norm of initial data:", np.linalg.norm(u.data[:]))
# devito_norm = norm(u)
# print("Init linalg norm 0 (inlined) :", norm(u))
# print("Init linalg norm 0 :", np.linalg.norm(u.data[0]))
# print("Init linalg norm 1 :", np.linalg.norm(u.data[1]))
# print("Init linalg norm 2 :", np.linalg.norm(u.data[2]))
# print("Norm of initial data:", np.linalg.norm(u.data[:]))

configuration['mpi'] = 0
u2.data[:] = u.data[:]
Expand All @@ -136,36 +132,33 @@ def plot_3dfunc(u):
op1 = Operator([stencil], name='DevitoOperator')
op1.apply(time=time_range.num-1, dt=model.critical_dt)

configuration['mpi'] = 0
ub.data[:] = u.data[:]
configuration['mpi'] = 'basic'

if len(shape) == 3:
if args.plot:
plot_3dfunc(u)

#devito_output = u.data[:]
print("After Operator 1: Devito norm:", np.linalg.norm(u.data[:]))
print("Devito linalg norm 0:", np.linalg.norm(u.data[0]))
print("Devito linalg norm 1:", np.linalg.norm(u.data[1]))
print("Devito linalg norm 2:", np.linalg.norm(u.data[2]))
# print("After Operator 1: Devito norm:", np.linalg.norm(u.data[:]))
#print("Devito norm 0:", np.linalg.norm(u.data[0]))
#print("Devito norm 1:", np.linalg.norm(u.data[1]))
#print("Devito norm 2:", np.linalg.norm(u.data[2]))

# Reset initial data
configuration['mpi'] = 0
u.data[:] = u2.data[:]
configuration['mpi'] = 'basic'
#v[:, ..., :] = 1

print("Reinitialise data for XDSL:", np.linalg.norm(u.data[:]))
print("Init XDSL linalg norm 0:", np.linalg.norm(u.data[0]))
print("Init XDSL linalg norm 1:", np.linalg.norm(u.data[1]))
print("Init XDSL linalg norm 2:", np.linalg.norm(u.data[2]))
# print("Reinitialise data for XDSL:", np.linalg.norm(u.data[:]))
# print("Init XDSL linalg norm 0:", np.linalg.norm(u.data[0]))
# print("Init XDSL linalg norm 1:", np.linalg.norm(u.data[1]))
# print("Init XDSL linalg norm 2:", np.linalg.norm(u.data[2]))

# Run more with no sources now (Not supported in xdsl)
xdslop = Operator([stencil], name='Operator')
xdslop = XDSLOperator([stencil], name='XDSLOperator')
xdslop.apply(time=time_range.num-1, dt=model.critical_dt)

xdsl_output = u.copy()

print("XDSL norm:", norm(u))
print(f"xdsl output norm: {norm(xdsl_output)}")
print("XDSL output linalg norm 0:", np.linalg.norm(u.data[0]))
print("XDSL output linalg norm 1:", np.linalg.norm(u.data[1]))
print("XDSL output linalg norm 2:", np.linalg.norm(u.data[2]))
print("devito-norm:", devito_norm)
print("XDSL output norm 0:", np.linalg.norm(u.data[0]), "vs:", np.linalg.norm(ub.data[0]))
print("XDSL output norm 1:", np.linalg.norm(u.data[1]), "vs:", np.linalg.norm(ub.data[1]))
print("XDSL output norm 2:", np.linalg.norm(u.data[2]), "vs:", np.linalg.norm(ub.data[2]))

0 comments on commit 625e976

Please sign in to comment.