Skip to content

Commit

Permalink
Add --noremove option to tiling jobs. Pull out alignment code check into a separate function (#63)
Browse files Browse the repository at this point in the history

* add --noremove option

* build RIOS and workaround single tile cases

* Dockerfile fixes

* remove RIOS. Pull out alignment checks into separate function. Fix array job syntax

* workaround AWS_BATCH_JOB_ARRAY_INDEX not being set if one tile

* tidy

* typo

* list

* AWS resets AWS_BATCH_JOB_ARRAY_INDEX so use a command line arg instead
  • Loading branch information
gillins authored Jun 27, 2024
1 parent 34b31ac commit 1822bba
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 67 deletions.
16 changes: 15 additions & 1 deletion parallel_examples/awsbatch/do_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def getCmdargs():
help="Maximum spectral difference for segmentation (default=%(default)s)")
p.add_argument("--spectDistPcntile", type=int, default=50, required=False,
help="Spectral Distance Percentile for segmentation (default=%(default)s)")
p.add_argument("--noremove", action="store_true", default=False,
help="don't remove files from S3 (for debugging)")

cmdargs = p.parse_args()
if cmdargs.bands is not None:
Expand Down Expand Up @@ -116,10 +118,20 @@ def main():
'--minSegmentSize', str(cmdargs.minSegmentSize),
'--maxSpectDiff', cmdargs.maxSpectDiff,
'--spectDistPcntile', str(cmdargs.spectDistPcntile)]}

arrayProperties = {}
if len(colRowList) > 1:
# throws error if this is 1...
arrayProperties['size'] = len(colRowList)
else:
# must fake AWS_BATCH_JOB_ARRAY_INDEX
        # can't set this as an env var as Batch overrides it
containerOverrides['command'].extend(['--arrayindex', '0'])

response = batch.submit_job(jobName="pyshepseg_tiles",
jobQueue=cmdargs.jobqueue,
jobDefinition=cmdargs.jobdefntile,
arrayProperties={'size': len(colRowList)},
arrayProperties=arrayProperties,
containerOverrides=containerOverrides)
tilesJobId = response['jobId']
print('Tiles Job Id', tilesJobId)
Expand All @@ -137,6 +149,8 @@ def main():
cmd.extend(['--spatialstats', cmdargs.spatialstats])
if cmdargs.nogdalstats:
cmd.append('--nogdalstats')
if cmdargs.noremove:
cmd.append('--noremove')

response = batch.submit_job(jobName="pyshepseg_stitch",
jobQueue=cmdargs.jobqueue,
Expand Down
36 changes: 20 additions & 16 deletions parallel_examples/awsbatch/do_stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def getCmdargs():
p.add_argument("--nogdalstats", action="store_true", default=False,
help="don't calculate GDAL's statistics or write a colour table. " +
"Can't be used with --stats.")
p.add_argument("--noremove", action="store_true", default=False,
help="don't remove files from S3 (for debugging)")

cmdargs = p.parse_args()

Expand Down Expand Up @@ -94,15 +96,16 @@ def main():
cmdargs.overlapsize, tempDir, writeHistogram=True)

# clean up files to release space
objs = []
for col, row in tileFilenames:
filename = '{}_{}_{}.{}'.format(cmdargs.tileprefix, col, row, 'tif')
objs.append({'Key': filename})

# workaround 1000 at a time limit
while len(objs) > 0:
s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs[0:1000]})
del objs[0:1000]
if not cmdargs.noremove:
objs = []
for col, row in tileFilenames:
filename = '{}_{}_{}.{}'.format(cmdargs.tileprefix, col, row, 'tif')
objs.append({'Key': filename})

# workaround 1000 at a time limit
while len(objs) > 0:
s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs[0:1000]})
del objs[0:1000]

if not cmdargs.nogdalstats:
band = localDs.GetRasterBand(1)
Expand Down Expand Up @@ -155,13 +158,14 @@ def main():
s3.upload_file(localOutfile, cmdargs.bucket, cmdargs.outfile)

# cleanup temp files from S3
objs = [{'Key': cmdargs.pickle}]
if cmdargs.stats is not None:
objs.append({'Key': statsKey})
if cmdargs.spatialstats is not None:
objs.append({'Key': spatialstatsKey})

s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs})
if not cmdargs.noremove:
objs = [{'Key': cmdargs.pickle}]
if cmdargs.stats is not None:
objs.append({'Key': statsKey})
if cmdargs.spatialstats is not None:
objs.append({'Key': spatialstatsKey})

s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs})

# cleanup
shutil.rmtree(tempDir)
Expand Down
17 changes: 9 additions & 8 deletions parallel_examples/awsbatch/do_tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@

gdal.UseExceptions()

# set by AWS Batch
ARRAY_INDEX = os.getenv('AWS_BATCH_JOB_ARRAY_INDEX')
if ARRAY_INDEX is None:
raise SystemExit('Must set AWS_BATCH_JOB_ARRAY_INDEX env var')

ARRAY_INDEX = int(ARRAY_INDEX)


def getCmdargs():
"""
Expand All @@ -48,9 +41,17 @@ def getCmdargs():
help="Maximum spectral difference for segmentation (default=%(default)s)")
p.add_argument("--spectDistPcntile", type=int, default=50, required=False,
help="Spectral Distance Percentile for segmentation (default=%(default)s)")
p.add_argument("--arrayindex", type=int,
help="Override AWS_BATCH_JOB_ARRAY_INDEX env var")

cmdargs = p.parse_args()

if cmdargs.arrayindex is None:
cmdargs.arrayindex = os.getenv('AWS_BATCH_JOB_ARRAY_INDEX')
if cmdargs.arrayindex is None:
raise SystemExit('Must set AWS_BATCH_JOB_ARRAY_INDEX env var or ' +
'specify --arrayindex')

return cmdargs


Expand All @@ -75,7 +76,7 @@ def main():
tempDir = tempfile.mkdtemp()

# work out which tile we are processing
col, row = dataFromPickle['colRowList'][ARRAY_INDEX]
col, row = dataFromPickle['colRowList'][cmdargs.arrayindex]

# work out a filename to save with the output of this tile
# Note: this filename format is repeated in do_stitch.py
Expand Down
4 changes: 4 additions & 0 deletions parallel_examples/awsbatch/submit-pyshepseg-job.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def getCmdargs():
help="Maximum spectral difference for segmentation (default=%(default)s)")
p.add_argument("--spectDistPcntile", type=int, default=50, required=False,
help="Spectral Distance Percentile for segmentation (default=%(default)s)")
p.add_argument("--noremove", action="store_true", default=False,
help="don't remove files from S3 (for debugging)")

cmdargs = p.parse_args()

Expand Down Expand Up @@ -98,6 +100,8 @@ def main():
cmd.append('--nogdalstats')
if cmdargs.tileprefix is not None:
cmd.extend(['--tileprefix', cmdargs.tileprefix])
if cmdargs.noremove:
cmd.append('--noremove')

# submit the prepare job
response = batch.submit_job(jobName="pyshepseg_prepare",
Expand Down
98 changes: 56 additions & 42 deletions pyshepseg/tilingstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,8 @@ def calcPerSegmentStatsTiled(imgfile, imgbandnum, segfile,
valid pixels (not nodata) that were used to calculate the statistics.
"""
segds = segfile
if not isinstance(segds, gdal.Dataset):
segds = gdal.Open(segfile, gdal.GA_Update)
segband = segds.GetRasterBand(1)

imgds = imgfile
if not isinstance(imgds, gdal.Dataset):
imgds = gdal.Open(imgfile, gdal.GA_ReadOnly)
imgband = imgds.GetRasterBand(imgbandnum)
if (imgband.DataType == gdal.GDT_Float32 or
imgband.DataType == gdal.GDT_Float64):
raise PyShepSegStatsError("Float image types not supported")

if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize:
raise PyShepSegStatsError("Images must be same size")

if segds.GetGeoTransform() != imgds.GetGeoTransform():
raise PyShepSegStatsError("Images must have same spatial extent and pixel size")

if not equalProjection(segds.GetProjection(), imgds.GetProjection()):
raise PyShepSegStatsError("Images must be in the same projection")
segds, segband, imgds, imgband = doImageAlignmentChecks(segfile,
imgfile, imgbandnum)

attrTbl = segband.GetDefaultRAT()
existingColNames = [attrTbl.GetNameOfCol(i)
Expand Down Expand Up @@ -184,6 +165,58 @@ def calcPerSegmentStatsTiled(imgfile, imgbandnum, segfile,
raise PyShepSegStatsError('Not all pixels found during processing')


def doImageAlignmentChecks(segfile, imgfile, imgbandnum):
    """
    Check that the segmentation raster and the statistics image are
    properly aligned before any stats are collected. If they do not
    line up we refuse to proceed, since deciding how to resample or
    shift them is the user's responsibility. Also rejects floating
    point statistics images, which are not supported.

    Parameters
    ----------
    segfile : str or gdal.Dataset
        Path to segmented file or an open GDAL dataset.
    imgfile : str or gdal.Dataset
        Path to (or open dataset of) the input file for collecting
        statistics from
    imgbandnum : int
        1-based index of the band number in imgfile to use for collecting stats

    Returns
    -------
    segds: gdal.Dataset
        Opened GDAL datset for the segments file
    segband: gdal.Band
        First Band of the segds
    imgds: gdal.Dataset
        Opened GDAL dataset for the image data file
    imgband: gdal.Band
        Requested band for the imgds

    Raises
    ------
    PyShepSegStatsError
        If the image is a float type, or the two rasters differ in
        size, geotransform or projection.
    """
    # Accept either an already-open dataset or a filename for each input.
    # The segment file is opened for update so the RAT can be written later.
    if isinstance(segfile, gdal.Dataset):
        segds = segfile
    else:
        segds = gdal.Open(segfile, gdal.GA_Update)
    segband = segds.GetRasterBand(1)

    if isinstance(imgfile, gdal.Dataset):
        imgds = imgfile
    else:
        imgds = gdal.Open(imgfile, gdal.GA_ReadOnly)
    imgband = imgds.GetRasterBand(imgbandnum)

    # The stats code only handles integer imagery
    if imgband.DataType in (gdal.GDT_Float32, gdal.GDT_Float64):
        raise PyShepSegStatsError("Float image types not supported")

    # Same pixel grid dimensions?
    sameSize = (segband.XSize == imgband.XSize and
        segband.YSize == imgband.YSize)
    if not sameSize:
        raise PyShepSegStatsError("Images must be same size")

    # Same origin and pixel size?
    if segds.GetGeoTransform() != imgds.GetGeoTransform():
        raise PyShepSegStatsError("Images must have same spatial extent and pixel size")

    # Same coordinate reference system?
    if not equalProjection(segds.GetProjection(), imgds.GetProjection()):
        raise PyShepSegStatsError("Images must be in the same projection")

    return segds, segband, imgds, imgband


@njit
def accumulateSegDict(segDict, noDataDict, imgNullVal, tileSegments, tileImageData):
"""
Expand Down Expand Up @@ -1028,28 +1061,9 @@ def calcPerSegmentSpatialStatsTiled(imgfile, imgbandnum, segfile,
The value to fill in for segments that have no data.
"""
segds = segfile
if not isinstance(segds, gdal.Dataset):
segds = gdal.Open(segfile, gdal.GA_Update)
segband = segds.GetRasterBand(1)
segds, segband, imgds, imgband = doImageAlignmentChecks(segfile,
imgfile, imgbandnum)

imgds = imgfile
if not isinstance(imgds, gdal.Dataset):
imgds = gdal.Open(imgfile, gdal.GA_ReadOnly)
imgband = imgds.GetRasterBand(imgbandnum)
if (imgband.DataType == gdal.GDT_Float32 or
imgband.DataType == gdal.GDT_Float64):
raise PyShepSegStatsError("Float image types not supported")

if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize:
raise PyShepSegStatsError("Images must be same size")

if segds.GetGeoTransform() != imgds.GetGeoTransform():
raise PyShepSegStatsError("Images must have same spatial extent and pixel size")

if not equalProjection(segds.GetProjection(), imgds.GetProjection()):
raise PyShepSegStatsError("Images must be in the same projection")

attrTbl = segband.GetDefaultRAT()
existingColNames = [attrTbl.GetNameOfCol(i)
for i in range(attrTbl.GetColumnCount())]
Expand Down

0 comments on commit 1822bba

Please sign in to comment.