Skip to content

Commit

Permalink
automated DFT decimation for adjoint sources (#1753)
Browse files Browse the repository at this point in the history
* remove tol parameter from get_fwidth and add new set_fwidth function

* add fwidth parameter to custom_src_time class

* compute bandwidth of Nuttall DTFT window function using asymptotic power law

* fixes and docs

* add set_fwidth and get_fwidth functions to custom_py_src_time class

* adjust tolerances of failing unit tests

* remove fwidth from formula for fitting coefficient

* revert changes to tolerances in unit tests

* slightly increase tolerance for single-precision case of DFT fields test
  • Loading branch information
oskooi authored Oct 13, 2021
1 parent 16f4526 commit 8f0deee
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 58 deletions.
6 changes: 6 additions & 0 deletions doc/docs/Python_User_Interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -6627,6 +6627,7 @@ def __init__(self,
start_time=-1e+20,
end_time=1e+20,
center_frequency=0,
fwidth=0,
**kwargs):
```

Expand Down Expand Up @@ -6658,6 +6659,11 @@ Construct a `CustomSource`.
+ **`center_frequency` [`number`]** — Optional center frequency so that the
`CustomSource` can be used within an `EigenModeSource`. Defaults to 0.

+ **`fwidth` [`number`]** — Optional bandwidth in frequency units.
Default is 0. For bandwidth-limited sources, this parameter is used to
automatically determine the decimation factor of the time-series updates
of the DFT fields monitors (if any).

</div>

</div>
Expand Down
24 changes: 22 additions & 2 deletions python/adjoint/filter_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def __init__(
self.t = np.arange(0, dt * (self.N), dt)
self.n = np.arange(self.N)
f = self.func()

# frequency bandwidth of the Nuttall window function
fwidth = self.nuttall_bandwidth()

self.bf = [
lambda t, i=i: 0
if t > self.T else (self.nuttall(t, self.center_frequencies) /
Expand All @@ -36,7 +40,8 @@ def __init__(
CustomSource(src_func=bfi,
center_frequency=center_frequency,
is_integrated=False,
end_time=self.T) for bfi in self.bf
end_time=self.T,
fwidth=fwidth) for bfi in self.bf
]

if time_src:
Expand All @@ -58,7 +63,8 @@ def __init__(
super(FilteredSource, self).__init__(src_func=f,
center_frequency=center_frequency,
is_integrated=False,
end_time=self.T)
end_time=self.T,
fwidth=fwidth)

def cos_window_td(self, a, t, f0):
cos_sum = np.sum([(-1)**k * a[k] * np.cos(2 * np.pi * t * k / self.T)
Expand Down Expand Up @@ -103,6 +109,20 @@ def nuttall_dtft(self, f, f0):
a = [0.355768, 0.4873960, 0.144232, 0.012604]
return self.cos_window_fd(a, f, f0)

## compute the bandwidth of the DTFT of the Nuttall window function
## (magnitude) assuming it has decayed from its peak value by some
## tolerance by fitting it to an asymptotic power law of the form
## C / f^3 where C is a constant and f is the frequency
def nuttall_bandwidth(self):
tol = 1e-7
fwidth = 1/(self.N * self.dt)
frq_inf = 10000*fwidth
na_dtft = self.nuttall_dtft(frq_inf, 0)
coeff = frq_inf**3 * np.abs(na_dtft)
na_dtft_max = self.nuttall_dtft(0, 0)
bw = 2 * np.power(coeff / (tol * na_dtft_max), 1/3)
return bw.real

def dtft(self, y, f):
return np.matmul(
np.exp(1j * 2 * np.pi * f[:, np.newaxis] * np.arange(y.size) *
Expand Down
10 changes: 6 additions & 4 deletions python/meep-python.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ namespace meep {
class custom_py_src_time : public src_time {
public:
custom_py_src_time(PyObject *fun, double st = -infinity, double et = infinity,
std::complex<double> f = 0)
: func(fun), freq(f), start_time(float(st)), end_time(float(et)) {
std::complex<double> f = 0, double fw = 0)
: func(fun), freq(f), start_time(float(st)), end_time(float(et)), fwidth(fw) {
SWIG_PYTHON_THREAD_SCOPED_BLOCK;
Py_INCREF(func);
}
Expand Down Expand Up @@ -50,17 +50,19 @@ class custom_py_src_time : public src_time {
const custom_py_src_time *tp = dynamic_cast<const custom_py_src_time *>(&t);
if (tp)
return (tp->start_time == start_time && tp->end_time == end_time && tp->func == func &&
tp->freq == freq);
tp->freq == freq && tp->fwidth == fwidth);
else
return 0;
}
virtual std::complex<double> frequency() const { return freq; }
virtual void set_frequency(std::complex<double> f) { freq = f; }
virtual double get_fwidth() const { return fwidth; };
virtual void set_fwidth(double fw) { fwidth = fw; }

private:
PyObject *func;
std::complex<double> freq;
double start_time, end_time;
double start_time, end_time, fwidth;
};

} // namespace meep
11 changes: 9 additions & 2 deletions python/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class CustomSource(SourceTime):
[`examples/chirped_pulse.py`](https://github.com/NanoComp/meep/blob/master/python/examples/chirped_pulse.py).
"""

def __init__(self, src_func, start_time=-1.0e20, end_time=1.0e20, center_frequency=0, **kwargs):
def __init__(self, src_func, start_time=-1.0e20, end_time=1.0e20, center_frequency=0, fwidth=0, **kwargs):
"""
Construct a `CustomSource`.
Expand All @@ -296,13 +296,20 @@ def __init__(self, src_func, start_time=-1.0e20, end_time=1.0e20, center_frequen
+ **`center_frequency` [`number`]** — Optional center frequency so that the
`CustomSource` can be used within an `EigenModeSource`. Defaults to 0.
+ **`fwidth` [`number`]** — Optional bandwidth in frequency units.
Default is 0. For bandwidth-limited sources, this parameter is used to
automatically determine the decimation factor of the time-series updates
of the DFT fields monitors (if any).
"""
super(CustomSource, self).__init__(**kwargs)
self.src_func = src_func
self.start_time = start_time
self.end_time = end_time
self.fwidth = fwidth
self.center_frequency = center_frequency
self.swigobj = mp.custom_py_src_time(src_func, start_time, end_time, center_frequency)
self.swigobj = mp.custom_py_src_time(src_func, start_time, end_time,
center_frequency, fwidth)
self.swigobj.is_integrated = self.is_integrated


Expand Down
35 changes: 20 additions & 15 deletions python/tests/test_adjoint_cyl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
design_region_resolution = int(2*resolution)
design_r = 4.8
design_z = 2
Nx = int(design_region_resolution*design_r)
Nz = int(design_region_resolution*design_z)
Nr = int(design_region_resolution*design_r) + 1
Nz = int(design_region_resolution*design_z) + 1

fcen = 1/1.55
width = 0.2
Expand All @@ -37,20 +37,20 @@
src = mp.GaussianSource(frequency=fcen,fwidth=fwidth)
source = [mp.Source(src,component=mp.Er,
center=source_center,
size=source_size)]
size=source_size)]

## random design region
p = np.random.rand(Nx*Nz)
p = np.random.rand(Nr*Nz)
## random epsilon perturbation for design region
deps = 1e-5
dp = deps*np.random.rand(Nx*Nz)
dp = deps*np.random.rand(Nr*Nz)


def forward_simulation(design_params):
matgrid = mp.MaterialGrid(mp.Vector3(Nx,0,Nz),
matgrid = mp.MaterialGrid(mp.Vector3(Nr,0,Nz),
SiO2,
Si,
weights=design_params.reshape(Nx,1,Nz))
weights=design_params.reshape(Nr,1,Nz))

geometry = [mp.Block(center=mp.Vector3(0.1+design_r/2,0,0),
size=mp.Vector3(design_r,0,design_z),
Expand All @@ -68,9 +68,8 @@ def forward_simulation(design_params):
far_x = [mp.Vector3(5,0,20)]
mode = sim.add_near2far(
frequencies,
mp.Near2FarRegion(center=mp.Vector3(0.1+design_r/2,0 ,(sz/2-dpml+design_z/2)/2),size=mp.Vector3(design_r,0,0), weight=+1),
decimation_factor=10
)
mp.Near2FarRegion(center=mp.Vector3(0.1+design_r/2,0,(sz/2-dpml+design_z/2)/2),
size=mp.Vector3(design_r,0,0),weight=+1))

sim.run(until_after_sources=1200)
Er = sim.get_farfield(mode, far_x[0])
Expand All @@ -81,9 +80,13 @@ def forward_simulation(design_params):

def adjoint_solver(design_params):

design_variables = mp.MaterialGrid(mp.Vector3(Nx,0,Nz),SiO2,Si)
design_region = mpa.DesignRegion(design_variables,volume=mp.Volume(center=mp.Vector3(0.1+design_r/2,0,0), size=mp.Vector3(design_r, 0,design_z)))
geometry = [mp.Block(center=design_region.center, size=design_region.size, material=design_variables)]
design_variables = mp.MaterialGrid(mp.Vector3(Nr,0,Nz),SiO2,Si)
design_region = mpa.DesignRegion(design_variables,
volume=mp.Volume(center=mp.Vector3(0.1+design_r/2,0,0),
size=mp.Vector3(design_r,0,design_z)))
geometry = [mp.Block(center=design_region.center,
size=design_region.size,
material=design_variables)]

sim = mp.Simulation(cell_size=cell_size,
boundary_layers=boundary_layers,
Expand All @@ -94,8 +97,10 @@ def adjoint_solver(design_params):
m=m)

far_x = [mp.Vector3(5,0,20)]
NearRegions = [mp.Near2FarRegion(center=mp.Vector3(0.1+design_r/2,0 ,(sz/2-dpml+design_z/2)/2),size=mp.Vector3(design_r,0,0), weight=+1)]
FarFields = mpa.Near2FarFields(sim, NearRegions ,far_x, decimation_factor=5)
NearRegions = [mp.Near2FarRegion(center=mp.Vector3(0.1+design_r/2,0,(sz/2-dpml+design_z/2)/2),
size=mp.Vector3(design_r,0,0),
weight=+1)]
FarFields = mpa.Near2FarFields(sim, NearRegions ,far_x)
ob_list = [FarFields]

def J(alpha):
Expand Down
3 changes: 1 addition & 2 deletions python/tests/test_adjoint_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def build_straight_wg_simulation(
mpa.EigenmodeCoefficient(simulation,
mp.Volume(center=center, size=monitor_size),
mode=1,
forward=forward,
decimation_factor=5)
forward=forward)
for center in monitor_centers for forward in [True, False]
]
return simulation, sources, monitors, design_regions, frequencies
Expand Down
17 changes: 6 additions & 11 deletions python/tests/test_adjoint_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,14 @@ def forward_simulation(design_params,mon_type, frequencies=None, use_complex=Fal
mp.ModeRegion(center=mp.Vector3(0.5*sxy-dpml-0.1),
size=mp.Vector3(0,sxy-2*dpml,0)),
yee_grid=True,
decimation_factor=10,
eig_parity=eig_parity)

elif mon_type.name == 'DFT':
mode = sim.add_dft_fields([mp.Ez],
frequencies,
center=mp.Vector3(1.25),
size=mp.Vector3(0.25,1,0),
yee_grid=False,
decimation_factor=10)
yee_grid=False)

sim.run(until_after_sources=mp.stop_when_dft_decayed())

Expand Down Expand Up @@ -145,7 +143,6 @@ def adjoint_solver(design_params, mon_type, frequencies=None, use_complex=False,
mp.Volume(center=mp.Vector3(0.5*sxy-dpml-0.1),
size=mp.Vector3(0,sxy-2*dpml,0)),
1,
decimation_factor=5,
eig_parity=eig_parity)]

def J(mode_mon):
Expand All @@ -155,8 +152,7 @@ def J(mode_mon):
obj_list = [mpa.FourierFields(sim,
mp.Volume(center=mp.Vector3(1.25),
size=mp.Vector3(0.25,1,0)),
mp.Ez,
decimation_factor=5)]
mp.Ez)]

def J(mode_mon):
return npa.power(npa.abs(mode_mon[:,4,10]),2)
Expand All @@ -166,8 +162,7 @@ def J(mode_mon):
objective_functions=J,
objective_arguments=obj_list,
design_regions=[matgrid_region],
frequencies=frequencies,
decimation_factor=10)
frequencies=frequencies)

f, dJ_du = opt([design_params])

Expand Down Expand Up @@ -213,7 +208,7 @@ def test_adjoint_solver_DFT_fields(self):
adj_scale = (dp[None,:]@adjsol_grad).flatten()
fd_grad = S12_perturbed-S12_unperturbed
print("Directional derivative -- adjoint solver: {}, FD: {}".format(adj_scale,fd_grad))
tol = 0.04 if mp.is_single_precision() else 0.005
tol = 0.0461 if mp.is_single_precision() else 0.005
self.assertClose(adj_scale,fd_grad,epsilon=tol)


Expand Down Expand Up @@ -267,14 +262,14 @@ def test_gradient_backpropagation(self):
bp_adjsol_grad = tensor_jacobian_product(mapping,0)(p,filter_radius,eta,beta,adjsol_grad)

## compute unperturbed S12
S12_unperturbed = forward_simulation(mapped_p, MonitorObject.EIGENMODE,frequencies)
S12_unperturbed = forward_simulation(mapped_p,MonitorObject.EIGENMODE,frequencies)

## compare objective results
print("S12 -- adjoint solver: {}, traditional simulation: {}".format(adjsol_obj,S12_unperturbed))
self.assertClose(adjsol_obj,S12_unperturbed,epsilon=1e-6)

## compute perturbed S12
S12_perturbed = forward_simulation(mapping(p+dp,filter_radius,eta,beta), MonitorObject.EIGENMODE,frequencies)
S12_perturbed = forward_simulation(mapping(p+dp,filter_radius,eta,beta),MonitorObject.EIGENMODE,frequencies)

if bp_adjsol_grad.ndim < 2:
bp_adjsol_grad = np.expand_dims(bp_adjsol_grad,axis=1)
Expand Down
7 changes: 3 additions & 4 deletions src/dft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,12 @@ dft_chunk *fields::add_dft(component c, const volume &where, const double *freq,
data.vc = vc;

if (decimation_factor == 0) {
double tol = 1e-7;
double src_freq_max = 0;
for (src_time *s = sources; s; s = s->next) {
if (s->get_fwidth(tol) == 0)
if (s->get_fwidth() == 0)
decimation_factor = 1;
else
src_freq_max = std::max(src_freq_max, std::abs(s->frequency().real())+0.5*s->get_fwidth(tol));
src_freq_max = std::max(src_freq_max, std::abs(s->frequency().real())+0.5*s->get_fwidth());
}
double freq_max = 0;
for (size_t i = 0; i < Nfreq; ++i)
Expand Down Expand Up @@ -1376,4 +1375,4 @@ void fields::get_mode_mode_overlap(void *mode1_data, void *mode2_data, dft_flux
get_overlap(mode1_data, mode2_data, flux, 0, overlaps);
}

} // namespace meep
} // namespace meep
19 changes: 11 additions & 8 deletions src/meep.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,8 @@ class src_time {
return 1;
}
virtual std::complex<double> frequency() const { return 0.0; }
virtual double get_fwidth(double tol) const { (void)tol; return 0.0; }
virtual double get_fwidth() const { return 0.0; }
virtual void set_fwidth(double fw) { (void)fw; }
virtual void set_frequency(std::complex<double> f) { (void)f; }

private:
Expand All @@ -1010,12 +1011,13 @@ class gaussian_src_time : public src_time {
virtual src_time *clone() const { return new gaussian_src_time(*this); }
virtual bool is_equal(const src_time &t) const;
virtual std::complex<double> frequency() const { return freq; }
virtual double get_fwidth(double tol) const;
virtual double get_fwidth() const { return fwidth; };
virtual void set_fwidth(double fw) { fwidth = fw; };
virtual void set_frequency(std::complex<double> f) { freq = real(f); }
std::complex<double> fourier_transform(const double f);

private:
double freq, width, peak_time, cutoff;
double freq, fwidth, width, peak_time, cutoff;
};

// Continuous (CW) source with (optional) slow turn-on and/or turn-off.
Expand All @@ -1031,7 +1033,7 @@ class continuous_src_time : public src_time {
virtual src_time *clone() const { return new continuous_src_time(*this); }
virtual bool is_equal(const src_time &t) const;
virtual std::complex<double> frequency() const { return freq; }
virtual double get_fwidth(double tol) const { (void)tol; return 0.0; };
virtual double get_fwidth() const { return 0.0; };
virtual void set_frequency(std::complex<double> f) { freq = f; }

private:
Expand All @@ -1043,8 +1045,8 @@ class continuous_src_time : public src_time {
class custom_src_time : public src_time {
public:
custom_src_time(std::complex<double> (*func)(double t, void *), void *data, double st = -infinity,
double et = infinity, std::complex<double> f = 0)
: func(func), data(data), freq(f), start_time(float(st)), end_time(float(et)) {}
double et = infinity, std::complex<double> f = 0, double fw = 0)
: func(func), data(data), freq(f), start_time(float(st)), end_time(float(et)), fwidth(fw) {}
virtual ~custom_src_time() {}

virtual std::complex<double> current(double time, double dt) const {
Expand All @@ -1064,14 +1066,15 @@ class custom_src_time : public src_time {
virtual src_time *clone() const { return new custom_src_time(*this); }
virtual bool is_equal(const src_time &t) const;
virtual std::complex<double> frequency() const { return freq; }
virtual double get_fwidth(double tol) const { (void)tol; return 0.0; };
virtual void set_frequency(std::complex<double> f) { freq = f; }
virtual double get_fwidth() const { return fwidth; };
virtual void set_fwidth(double fw) { fwidth = fw; }

private:
std::complex<double> (*func)(double t, void *);
void *data;
std::complex<double> freq;
double start_time, end_time;
double start_time, end_time, fwidth;
};

class monitor_point {
Expand Down
Loading

0 comments on commit 8f0deee

Please sign in to comment.