From ec7126b4f2f5d6432a0eed72cecdeacc6216b347 Mon Sep 17 00:00:00 2001 From: Simon White Date: Wed, 29 Nov 2023 14:03:13 +0100 Subject: [PATCH] Beam slice monitor (#692) * add passmethod * passmethod for slice moments * passmethod for slice moments * add python element * update tracking utils * correct ordering * correct mex name * correct mex bug * zcuts in mex * bugfix * bugfix * fix initialization * return nan instead of 0 * spos never 0 * change shape of attributes * added spos attribute * bugfix * add nslice to build attributes * reshape output * set spos as slice center * help + pep8 * help + pep8 * handle weights for zero beam current * spos centered on bucket not on bunch position in the ring --- atintegrators/BeamMomentsPass.c | 4 +- atintegrators/SliceMomentsPass.c | 235 +++++++++++++++++++++++++++++++ pyat/at/lattice/elements.py | 130 +++++++++++++++-- pyat/at/tracking/utils.py | 3 +- 4 files changed, 360 insertions(+), 12 deletions(-) create mode 100644 atintegrators/SliceMomentsPass.c diff --git a/atintegrators/BeamMomentsPass.c b/atintegrators/BeamMomentsPass.c index f07dde72e..3af8ab98b 100644 --- a/atintegrators/BeamMomentsPass.c +++ b/atintegrators/BeamMomentsPass.c @@ -122,8 +122,8 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) else if (nrhs == 0) { /* list of required fields */ plhs[0] = mxCreateCellMatrix(2,1); - mxSetCell(plhs[0],0,mxCreateString("_positions")); - mxSetCell(plhs[0],1,mxCreateString("_sizes")); + mxSetCell(plhs[0],0,mxCreateString("_means")); + mxSetCell(plhs[0],1,mxCreateString("_stds")); } else { mexErrMsgIdAndTxt("AT:WrongArg","Needs 2 or 0 arguments"); diff --git a/atintegrators/SliceMomentsPass.c b/atintegrators/SliceMomentsPass.c new file mode 100644 index 000000000..8ba718a78 --- /dev/null +++ b/atintegrators/SliceMomentsPass.c @@ -0,0 +1,235 @@ +#include "atelem.c" +#include "atimplib.c" +#include +#include +#include +#ifdef MPI +#include +#include +#endif + + +struct elem +{ + int startturn; + int endturn; + int turn; + int nslice; + double *stds; + double *means; + double *sposs; + double *weights; + double *z_cuts; +}; + + +static void slice_beam(double *r_in,int num_particles,int nslice,int turn, + int nturns, int nbunch, double *weights, double *sposs, + double *means, double *stds, double *z_cuts, + double *bunch_currents, double beam_current){ + + int i,ii,iii,ib; + double *rtmp; + + double *smin = atMalloc(nbunch*sizeof(double)); + double *smax = atMalloc(nbunch*sizeof(double)); + double *hz = atMalloc(nbunch*sizeof(double)); + double *np_bunch = atMalloc(nbunch*sizeof(double)); + getbounds(r_in,nbunch,num_particles,smin,smax,z_cuts); + + for(i=0;i= smin[ib]) && (rtmp[5] <= smax[ib])) { + if (rtmp[5] == smax[ib]){ + ii = nslice-1 + ib*nslice; + } + else { + ii = (int)(floor((rtmp[5]-smin[ib])/hz[ib])) + ib*nslice; + } + weight[ii] += 1.0; + for(iii=0; iii<3; iii++) { + pos[iii+ii*3] += rtmp[idx[iii]]; + std[iii+ii*3] += rtmp[idx[iii]]*rtmp[idx[iii]]; + } + } + } + + #ifdef MPI + MPI_Allreduce(MPI_IN_PLACE,np_bunch,nbunch,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE,pos,3*nslice*nbunch,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE,std,3*nslice*nbunch,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE,weight,nslice*nbunch,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD); + MPI_Barrier(MPI_COMM_WORLD); + #endif + + for (i=0;i 0){ + pos[3*i+ii] = pos[3*i+ii]/weight[i]; + std[3*i+ii] = sqrt(std[3*i+ii]/weight[i]-pos[3*i+ii]*pos[3*i+ii]); + } + else{ + pos[3*i+ii] = NAN; + std[3*i+ii] = NAN; + } + } + spos[i] = smin[ib]+(i%nslice+0.5)*hz[ib]; + if(beam_current>0.0){ + weight[i] *= bunch_currents[ib]/np_bunch[ib]; + }else{ + weight[i] *= 1.0/np_bunch[ib]; + } + } + + means += 3*nbunch*nslice*turn; + stds += 3*nbunch*nslice*turn; + sposs += nbunch*nslice*turn; + weights += nbunch*nslice*turn; + memcpy(means, pos, 3*nbunch*nslice*sizeof(double)); + memcpy(stds, std, 3*nbunch*nslice*sizeof(double)); + memcpy(sposs, spos, nbunch*nslice*sizeof(double)); + memcpy(weights, weight, nbunch*nslice*sizeof(double)); + + atFree(buffer); + atFree(np_bunch); + atFree(smin); + atFree(smax); + atFree(hz); +}; + + +void SliceMomentsPass(double *r_in, int nbunch, double *bunch_currents, + double beam_current, int num_particles, struct elem *Elem) { + + int startturn = Elem->startturn; + int endturn = Elem->endturn; + int nturns = endturn-startturn; + int turn = Elem->turn; + int nslice = Elem->nslice; + double *stds = Elem->stds; + double *means = Elem->means; + double *sposs = Elem->sposs; + double *weights = Elem->weights; + double *z_cuts = Elem->z_cuts; + + if((turn>=startturn) && (turn Param->num_turns){ + atWarning("endturn exceed the total number of turns"); + }; + int dims[] = {3, Param->nbunch*nslice, endturn-startturn}; + int dimsw[] = {Param->nbunch*nslice, endturn-startturn}; + means = atGetDoubleArray(ElemData,"_means"); check_error(); + stds = atGetDoubleArray(ElemData,"_stds"); check_error(); + sposs = atGetDoubleArray(ElemData,"_spos"); check_error(); + weights = atGetDoubleArray(ElemData,"_weights"); check_error(); + z_cuts=atGetOptionalDoubleArray(ElemData,"ZCuts"); check_error(); + atCheckArrayDims(ElemData,"_means", 3, dims); check_error(); + atCheckArrayDims(ElemData,"_stds", 3, dims); check_error(); + atCheckArrayDims(ElemData,"_spos", 2, dimsw); check_error(); + atCheckArrayDims(ElemData,"_weights", 2, dimsw); check_error(); + Elem = (struct elem*)atMalloc(sizeof(struct elem)); + Elem->stds = stds; + Elem->means = means; + Elem->sposs = sposs; + Elem->weights = weights; + Elem->turn = 0; + Elem->startturn = startturn; + Elem->endturn = endturn; + Elem->nslice = nslice; + Elem->z_cuts = z_cuts; + } + SliceMomentsPass(r_in, Param->nbunch, Param->bunch_currents, + Param->beam_current, num_particles, Elem); + Elem->turn++; + return Elem; +} + +MODULE_DEF(SliceMomentsPass) /* Dummy module initialisation */ +#endif + +#ifdef MATLAB_MEX_FILE + +void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) +{ + if (nrhs == 2) { + + double *r_in; + const mxArray *ElemData = prhs[0]; + int num_particles = mxGetN(prhs[1]); + struct elem El, *Elem=&El; + + double *means, *stds, *weights, *z_cuts; + int startturn = atGetLong(ElemData,"startturn"); check_error(); + int endturn = atGetLong(ElemData,"endturn"); check_error(); + int nslice = atGetLong(ElemData,"nslice"); check_error(); + means = atGetDoubleArray(ElemData,"_means"); check_error(); + stds = atGetDoubleArray(ElemData,"_stds"); check_error(); + weights = atGetDoubleArray(ElemData,"_weights"); check_error(); + z_cuts=atGetOptionalDoubleArray(ElemData,"ZCuts"); check_error(); + Elem = (struct elem*)atMalloc(sizeof(struct elem)); + Elem->stds = stds; + Elem->means = means; + Elem->weights = weights; + Elem->turn = 0; + Elem->startturn = startturn; + Elem->endturn = endturn; + Elem->nslice = nslice; + Elem->z_cuts = z_cuts; + if (mxGetM(prhs[1]) != 6) mexErrMsgIdAndTxt("AT:WrongArg","Second argument must be a 6 x N matrix: particle array"); + /* ALLOCATE memory for the output array of the same size as the input */ + plhs[0] = mxDuplicateArray(prhs[1]); + r_in = mxGetDoubles(plhs[0]); + double *bcurr = malloc(sizeof(double)); + bcurr[0] = 0.0; + SliceMomentsPass(r_in,1,bcurr, 1.0,num_particles,Elem); + } + else if (nrhs == 0) { + /* list of required fields */ + plhs[0] = mxCreateCellMatrix(6,1); + mxSetCell(plhs[0],0,mxCreateString("_means")); + mxSetCell(plhs[0],1,mxCreateString("_stds")); + mxSetCell(plhs[0],2,mxCreateString("_weights")); + mxSetCell(plhs[0],3,mxCreateString("startturn")); + mxSetCell(plhs[0],4,mxCreateString("endturn")); + mxSetCell(plhs[0],5,mxCreateString("nslice")); + } + else { + mexErrMsgIdAndTxt("AT:WrongArg","Needs 2 or 0 arguments"); + } +} +#endif diff --git a/pyat/at/lattice/elements.py b/pyat/at/lattice/elements.py index 6d2decc88..905d9916d 100644 --- a/pyat/at/lattice/elements.py +++ b/pyat/at/lattice/elements.py @@ -474,6 +474,12 @@ class BeamMoments(Element): """Element to compute bunches mean and std""" def __init__(self, family_name: str, **kwargs): + """ + Args: + family_name: Name of the element + + Default PassMethod: ``BeamMomentsPass`` + """ kwargs.setdefault('PassMethod', 'BeamMomentsPass') self._stds = numpy.zeros((6, 1, 1), order='F') self._means = numpy.zeros((6, 1, 1), order='F') @@ -485,13 +491,118 @@ def set_buffers(self, nturns, nbunch): @property def stds(self): + """Beam 6d standard deviation""" return self._stds @property def means(self): + """Beam 6d center of mass""" return self._means +class SliceMoments(Element): + """Element to compute slices mean and std""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['nslice'] + _conversions = dict(Element._conversions, nslice=int) + + def __init__(self, family_name: str, nslice: int, **kwargs): + """ + Args: + family_name: Name of the element + nslice: Number of slices + + Keyword arguments: + startturn: Start turn of the acquisition (Default 0) + endturn: End turn of the acquisition (Default 1) + + Default PassMethod: ``SliceMomentsPass`` + """ + kwargs.setdefault('PassMethod', 'SliceMomentsPass') + self._startturn = kwargs.pop('startturn', 0) + self._endturn = kwargs.pop('endturn', 1) + super(SliceMoments, self).__init__(family_name, nslice=nslice, + **kwargs) + self._nbunch = 1 + self.startturn = self._startturn + self.endturn = self._endturn + self._dturns = self.endturn - self.startturn + self._stds = numpy.zeros((3, nslice, self._dturns), order='F') + self._means = numpy.zeros((3, nslice, self._dturns), order='F') + self._spos = numpy.zeros((nslice, self._dturns), order='F') + self._weights = numpy.zeros((nslice, self._dturns), order='F') + self.set_buffers(self._endturn, 1) + + def set_buffers(self, nturns, nbunch): + self.endturn = min(self.endturn, nturns) + self._dturns = self.endturn - self.startturn + self._nbunch = nbunch + self._stds = numpy.zeros((3, nbunch*self.nslice, self._dturns), + order='F') + self._means = numpy.zeros((3, nbunch*self.nslice, self._dturns), + order='F') + self._spos = numpy.zeros((nbunch*self.nslice, self._dturns), + order='F') + self._weights = numpy.zeros((nbunch*self.nslice, self._dturns), + order='F') + + @property + def stds(self): + """Slices x,y,dp standard deviation""" + return self._stds.reshape((3, self._nbunch, + self.nslice, + self._dturns)) + + @property + def means(self): + """Slices x,y,dp center of mass""" + return self._means.reshape((3, self._nbunch, + self.nslice, + self._dturns)) + + @property + def spos(self): + """Slices s position""" + return self._spos.reshape((self._nbunch, + self.nslice, + self._dturns)) + + @property + def weights(self): + """Slices weights in mA if beam current >0, + otherwise fraction of total number of + particles in the bunch + """ + return self._weights.reshape((self._nbunch, + self.nslice, + self._dturns)) + + @property + def startturn(self): + """Start turn of the acquisition""" + return self._startturn + + @startturn.setter + def startturn(self, value): + if value < 0: + raise ValueError('startturn must be greater or equal to 0') + if value >= self._endturn: + raise ValueError('startturn must be smaller than endturn') + self._startturn = value + + @property + def endturn(self): + """End turn of the acquisition""" + return self._endturn + + @endturn.setter + def endturn(self, value): + if value <= 0: + raise ValueError('endturn must be greater than 0') + if value <= self._startturn: + raise ValueError('endturn must be greater than startturn') + self._endturn = value + + class Aperture(Element): """Aperture element""" _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['Limits'] @@ -777,7 +888,8 @@ def __init__(self, family_name: str, length: float, KickAngle: Correction deviation angles (H, V) FieldScaling: Scaling factor applied to the magnetic field - Available PassMethods: :ref:`BndMPoleSymplectic4Pass`, :ref:`BendLinearPass`, + Available PassMethods: :ref:`BndMPoleSymplectic4Pass`, + :ref:`BendLinearPass`, :ref:`ExactSectorBendPass`, :ref:`ExactRectangularBendPass`, :ref:`ExactRectBendPass`, BndStrMPoleSymplectic4Pass @@ -1004,7 +1116,7 @@ def __init__(self, family_name: str, betax: float = 1.0, """ Args: family_name: Name of the element - + Optional Args: betax: Horizontal beta function at element [m] betay: Vertical beta function at element [m] @@ -1015,14 +1127,14 @@ def __init__(self, family_name: str, betax: float = 1.0, tauy: Vertical damping time [turns] tauz: Longitudinal damping time [turns] U0: Energy Loss [eV] - + Default PassMethod: ``SimpleQuantDiffPass`` """ kwargs.setdefault('PassMethod', self.default_pass[True]) - + assert taux >= 0.0, 'taux must be greater than or equal to 0' self.taux = taux - + assert tauy >= 0.0, 'tauy must be greater than or equal to 0' self.tauy = tauy @@ -1033,24 +1145,23 @@ def __init__(self, family_name: str, betax: float = 1.0, self.emitx = emitx if emitx > 0.0: assert taux > 0.0, 'if emitx is given, taux must be non zero' - + assert emity >= 0.0, 'emity must be greater than or equal to 0' self.emity = emity if emity > 0.0: assert tauy > 0.0, 'if emity is given, tauy must be non zero' - + assert espread >= 0.0, 'espread must be greater than or equal to 0' self.espread = espread if espread > 0.0: assert tauz > 0.0, 'if espread is given, tauz must be non zero' - + self.U0 = U0 self.betax = betax self.betay = betay super(SimpleQuantDiff, self).__init__(family_name, **kwargs) - class Corrector(LongElement): """Corrector element""" _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['KickAngle'] @@ -1168,6 +1279,7 @@ def __init__(self, family_name: str, energy_loss: float, **kwargs): kwargs.setdefault('PassMethod', self.default_pass[False]) super().__init__(family_name, EnergyLoss=energy_loss, **kwargs) + Radiative.register(EnergyLoss) diff --git a/pyat/at/tracking/utils.py b/pyat/at/tracking/utils.py index 4ec026c05..70708f2db 100644 --- a/pyat/at/tracking/utils.py +++ b/pyat/at/tracking/utils.py @@ -27,8 +27,9 @@ def _set_beam_monitors(ring: Sequence[Element], nbunch: int, nturns: int): """Function to initialize the beam monitors""" monitors = list(refpts_iterator(ring, elements.BeamMoments)) + monitors += list(refpts_iterator(ring, elements.SliceMoments)) for m in monitors: - m.set_buffers(nturns, nbunch) + m.set_buffers(nturns, nbunch) return len(monitors) == 0