diff --git a/parallel_examples/awsbatch/Dockerfile b/parallel_examples/awsbatch/Dockerfile index f634f50..78b499f 100644 --- a/parallel_examples/awsbatch/Dockerfile +++ b/parallel_examples/awsbatch/Dockerfile @@ -37,18 +37,9 @@ RUN cd /tmp \ && make install \ && cd ../.. \ && rm -rf kealib-${KEALIB_VERSION} kealib-${KEALIB_VERSION}.tar.gz - -ENV RIOS_VERSION=2.0.3 -RUN cd /tmp \ - && wget -q https://github.com/ubarsc/rios/releases/download/rios-${RIOS_VERSION}/rios-${RIOS_VERSION}.tar.gz \ - && tar xf rios-${RIOS_VERSION}.tar.gz \ - && cd rios-${RIOS_VERSION} \ - && DEB_PYTHON_INSTALL_LAYOUT=deb_system pip install . \ - && cd .. \ - && rm -rf rios-${RIOS_VERSION} rios-${RIOS_VERSION}.tar.gz COPY pyshepseg-$PYSHEPSEG_VER.tar.gz /tmp -# install pyshegseg +# install RIOS RUN cd /tmp && tar xf pyshepseg-$PYSHEPSEG_VER.tar.gz \ && cd pyshepseg-$PYSHEPSEG_VER \ && DEB_PYTHON_INSTALL_LAYOUT=deb_system pip install . \ @@ -82,9 +73,9 @@ RUN apt-get autoremove -y && apt-get clean && rm -rf /var/lib/apt/lists/* USER $SERVICEUSER # a few quick tests +#RUN gdal_translate --formats | grep KEA RUN python3 -c 'from osgeo import gdal;assert(gdal.GetDriverByName("KEA") is not None)' RUN python3 -c 'from pyshepseg import tiling' -RUN python3 -c 'from rios import applier' # export the volume VOLUME $SW_VOLUME diff --git a/parallel_examples/awsbatch/do_prepare.py b/parallel_examples/awsbatch/do_prepare.py index 72a81b6..9dc9b33 100755 --- a/parallel_examples/awsbatch/do_prepare.py +++ b/parallel_examples/awsbatch/do_prepare.py @@ -111,7 +111,7 @@ def main(): # now submit an array job with all the tiles # (can't do this before now because we don't know how many tiles) - arrayProperties = None + arrayProperties = {} if len(colRowList) > 1: # throws error if this is 1... arrayProperties = {'size': len(colRowList)} diff --git a/pyshepseg/tilingstats.py b/pyshepseg/tilingstats.py index 7461c81..eb51f14 100644 --- a/pyshepseg/tilingstats.py +++ b/pyshepseg/tilingstats.py @@ -111,27 +111,8 @@ def calcPerSegmentStatsTiled(imgfile, imgbandnum, segfile, valid pixels (not nodata) that were used to calculate the statistics. """ - segds = segfile - if not isinstance(segds, gdal.Dataset): - segds = gdal.Open(segfile, gdal.GA_Update) - segband = segds.GetRasterBand(1) - - imgds = imgfile - if not isinstance(imgds, gdal.Dataset): - imgds = gdal.Open(imgfile, gdal.GA_ReadOnly) - imgband = imgds.GetRasterBand(imgbandnum) - if (imgband.DataType == gdal.GDT_Float32 or - imgband.DataType == gdal.GDT_Float64): - raise PyShepSegStatsError("Float image types not supported") - - if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize: - raise PyShepSegStatsError("Images must be same size") - - if segds.GetGeoTransform() != imgds.GetGeoTransform(): - raise PyShepSegStatsError("Images must have same spatial extent and pixel size") - - if not equalProjection(segds.GetProjection(), imgds.GetProjection()): - raise PyShepSegStatsError("Images must be in the same projection") + segds, segband, imgds, imgband = doImageAlignmentChecks(segfile, + imgfile, imgbandnum) attrTbl = segband.GetDefaultRAT() existingColNames = [attrTbl.GetNameOfCol(i) @@ -184,6 +165,58 @@ def calcPerSegmentStatsTiled(imgfile, imgbandnum, segfile, raise PyShepSegStatsError('Not all pixels found during processing') +def doImageAlignmentChecks(segfile, imgfile, imgbandnum): + """ + Do the checks that the segment file and image file that is being used to + collect the stats actually align. We refuse to process the files if they + don't as it is not clear how they should be made to line up - this is up + to the user to get right. Also checks that imgfile is not a float image. + + Parameters + ---------- + segfile : str or gdal.Dataset + Path to segmented file or an open GDAL dataset. + imgfile : string + Path to input file for collecting statistics from + imgbandnum : int + 1-based index of the band number in imgfile to use for collecting stats + + Returns + ------- + segds: gdal.Dataset + Opened GDAL datset for the segments file + segband: gdal.Band + First Band of the segds + imgds: gdal.Dataset + Opened GDAL dataset for the image data file + imgband: gdal.Band + Requested band for the imgds + """ + segds = segfile + if not isinstance(segds, gdal.Dataset): + segds = gdal.Open(segfile, gdal.GA_Update) + segband = segds.GetRasterBand(1) + + imgds = imgfile + if not isinstance(imgds, gdal.Dataset): + imgds = gdal.Open(imgfile, gdal.GA_ReadOnly) + imgband = imgds.GetRasterBand(imgbandnum) + if (imgband.DataType == gdal.GDT_Float32 or + imgband.DataType == gdal.GDT_Float64): + raise PyShepSegStatsError("Float image types not supported") + + if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize: + raise PyShepSegStatsError("Images must be same size") + + if segds.GetGeoTransform() != imgds.GetGeoTransform(): + raise PyShepSegStatsError("Images must have same spatial extent and pixel size") + + if not equalProjection(segds.GetProjection(), imgds.GetProjection()): + raise PyShepSegStatsError("Images must be in the same projection") + + return segds, segband, imgds, imgband + + @njit def accumulateSegDict(segDict, noDataDict, imgNullVal, tileSegments, tileImageData): """ @@ -1028,28 +1061,9 @@ def calcPerSegmentSpatialStatsTiled(imgfile, imgbandnum, segfile, The value to fill in for segments that have no data. """ - segds = segfile - if not isinstance(segds, gdal.Dataset): - segds = gdal.Open(segfile, gdal.GA_Update) - segband = segds.GetRasterBand(1) + segds, segband, imgds, imgband = doImageAlignmentChecks(segfile, + imgfile, imgbandnum) - imgds = imgfile - if not isinstance(imgds, gdal.Dataset): - imgds = gdal.Open(imgfile, gdal.GA_ReadOnly) - imgband = imgds.GetRasterBand(imgbandnum) - if (imgband.DataType == gdal.GDT_Float32 or - imgband.DataType == gdal.GDT_Float64): - raise PyShepSegStatsError("Float image types not supported") - - if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize: - raise PyShepSegStatsError("Images must be same size") - - if segds.GetGeoTransform() != imgds.GetGeoTransform(): - raise PyShepSegStatsError("Images must have same spatial extent and pixel size") - - if not equalProjection(segds.GetProjection(), imgds.GetProjection()): - raise PyShepSegStatsError("Images must be in the same projection") - attrTbl = segband.GetDefaultRAT() existingColNames = [attrTbl.GetNameOfCol(i) for i in range(attrTbl.GetColumnCount())]