
Geodesic Distance Transform In Python

In Python there is the distance_transform_edt function in the scipy.ndimage.morphology module (current SciPy versions expose it directly as scipy.ndimage.distance_transform_edt). I applied it to a simple case, to compute the distance from a single cell in a masked array …
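For reference, the single-seed case can be set up as in the minimal sketch below; the 5x5 grid and the seed at (0, 0) are arbitrary choices for illustration. distance_transform_edt returns, for every non-zero element, the Euclidean distance to the nearest zero element, so marking only the seed as 0 gives the distance from that seed. It has no notion of obstacles, though: the result is a straight-line distance even when masked cells lie in between, which is why a geodesic variant is needed.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

# Every cell is 1 ("foreground") except the seed, which is 0 ("background").
grid = np.ones((5, 5))
seed = (0, 0)  # arbitrary seed position for this example
grid[seed] = 0

# Euclidean distance from each cell to the nearest zero cell, i.e. to the seed.
dist = distance_transform_edt(grid)
print(dist[0, 3], dist[3, 4])  # → 3.0 5.0
```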

Solution 1:

First of all, thumbs up for a very clear and well-written question.

There is a very good and fast implementation of the Fast Marching Method, called scikit-fmm, for exactly this kind of problem. You can find the details here: http://pythonhosted.org//scikit-fmm/

Installing it might be the hardest part, but on Windows with Conda it's easy, since there is a 64-bit Conda package for Python 2.7: https://binstar.org/jmargeta/scikit-fmm

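From there on, just pass your masked array to it, as you do with your own function. skfmm.distance measures the distance from the zero contour of its input, so m is assumed to hold 0 at the seed cell and positive values elsewhere, with the obstacle cells masked:

import skfmm
distance = skfmm.distance(m)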

The results look similar, and I think they are even slightly better. Your approach (apparently) searches in eight distinct directions, which results in a somewhat 'octagonal-shaped' distance.
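The octagonal shape can be seen with a little arithmetic. Assuming the eight moves cost 1 (axial) and √2 (diagonal), and ignoring obstacles, the shortest grid path to an offset (dx, dy) takes min(dx, dy) diagonal steps and the rest axial steps. The helper name eight_neighbour_length is made up for this sketch:

```python
import math

def eight_neighbour_length(dx, dy):
    """Shortest 8-neighbour grid path length for an offset (dx, dy), with axial
    steps of length 1 and diagonal steps of length sqrt(2), and no obstacles."""
    dx, dy = abs(dx), abs(dy)
    diagonal = min(dx, dy)             # diagonal steps cover both axes at once
    straight = max(dx, dy) - diagonal  # the remainder must be axial steps
    return straight + diagonal * math.sqrt(2)

# Compare the grid path length with the true Euclidean distance.
for dx, dy in [(4, 0), (4, 4), (4, 2)]:
    print((dx, dy), eight_neighbour_length(dx, dy), math.hypot(dx, dy))
```

Along the axes and the diagonals both values agree; in between the grid path is longer (for (4, 2): 2 + 2√2 ≈ 4.83 versus √20 ≈ 4.47), so the iso-distance lines become octagons rather than circles.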


On my machine the scikit-fmm implementation is over 200x faster than your function.


Solution 2:

64-bit Windows binaries for scikit-fmm are now available from Christoph Gohlke.

http://www.lfd.uci.edu/~gohlke/pythonlibs/#scikit-fmm

Solution 3:

A faster (about 10x) implementation that achieves the same result as your geodesic_distance_transform:

import numpy
from scipy.ndimage import distance_transform_edt

def getMissingMask(slab):

    nan_mask=numpy.where(numpy.isnan(slab),1,0)
    if not hasattr(slab,'mask'):
        mask_mask=numpy.zeros(slab.shape)
    else:
        if slab.mask.size==1 and slab.mask==False:
            mask_mask=numpy.zeros(slab.shape)
        else:
            mask_mask=numpy.where(slab.mask,1,0)
    mask=numpy.where(mask_mask+nan_mask>0,1,0)

    return mask

def geodesic(img,seed):

    seedy,seedx=seed
    mask=getMissingMask(img)

    #----Call distance_transform_edt if no missing----
    if mask.sum()==0:
        slab=numpy.ones(img.shape)
        slab[seedy,seedx]=0
        return distance_transform_edt(slab)

    target=(1-mask).sum()
    dist=numpy.ones(img.shape)*numpy.inf
    dist[seedy,seedx]=0

    def expandDir(img,direction):
        if direction=='n':
            l1=img[0,:]
            img=numpy.roll(img,1,axis=0)
            img[0,:]=l1  # assign (not ==) so roll() does not wrap the opposite edge around
        elif direction=='s':
            l1=img[-1,:]
            img=numpy.roll(img,-1,axis=0)
            img[-1,:]=l1  # assign (not ==) so roll() does not wrap the opposite edge around
        elif direction=='e':
            l1=img[:,0]
            img=numpy.roll(img,1,axis=1)
            img[:,0]=l1
        elif direction=='w':
            l1=img[:,-1]
            img=numpy.roll(img,-1,axis=1)
            img[:,-1]=l1  # assign (not ==) so roll() does not wrap the opposite edge around
        elif direction=='ne':
            img=expandDir(img,'n')
            img=expandDir(img,'e')
        elif direction=='nw':
            img=expandDir(img,'n')
            img=expandDir(img,'w')
        elif direction=='sw':
            img=expandDir(img,'s')
            img=expandDir(img,'w')
        elif direction=='se':
            img=expandDir(img,'s')
            img=expandDir(img,'e')

        return img

    def expandIter(img):
        sqrt2=numpy.sqrt(2)
        tmps=[]
        for dirii,dd in zip(['n','s','e','w','ne','nw','sw','se'],\
                [1,]*4+[sqrt2,]*4):
            tmpii=expandDir(img,dirii)+dd
            tmpii=numpy.minimum(tmpii,img)
            tmps.append(tmpii)
        # element-wise minimum over the 8 candidates (no functools.reduce needed in Python 3)
        img=numpy.minimum.reduce(tmps)

        return img

    #----------------Iteratively expand----------------
    dist_old=dist
    while True:
        expand=expandIter(dist)
        dist=numpy.where(mask,dist,expand)
        nc=dist.size-len(numpy.where(dist==numpy.inf)[0])

        if nc>=target or numpy.all(dist_old==dist):
            break
        dist_old=dist

    return dist
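A side note on getMissingMask: NumPy's masked-array helpers can express the same "masked or NaN" test more compactly. The sketch below (the name missing_mask is made up) returns a boolean array rather than 0/1 integers; call .astype(int) on the result if integer flags are wanted.

```python
import numpy as np

def missing_mask(slab):
    """True where `slab` is masked or NaN; works for plain and masked arrays."""
    masked = np.ma.getmaskarray(slab)         # all-False array when there is no mask
    data = np.ma.getdata(slab).astype(float)  # underlying values, as floats for isnan
    return masked | np.isnan(data)
```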

Also note that if the mask forms more than one connected region (e.g. when you add another circle that does not touch the others), your function will fall into an endless loop.
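One way to avoid the endless-loop problem entirely is to treat the grid as a graph and run Dijkstra's algorithm. This is a swapped-in technique, neither your function nor scikit-fmm: a plain-Python sketch assuming 8-connectivity with step costs 1 and √2 and a boolean `blocked` array (a hypothetical input format, True for masked cells). It computes exact shortest paths on the 8-neighbour graph, so the distances keep the octagonal character of the iterative approach, and cells in regions not connected to the seed simply stay at inf.

```python
import heapq
import math
import numpy as np

# (row offset, column offset, step length) for the 8 neighbour moves
MOVES = [(-1, 0, 1.0), (1, 0, 1.0), (0, -1, 1.0), (0, 1, 1.0),
         (-1, -1, math.sqrt(2)), (-1, 1, math.sqrt(2)),
         (1, -1, math.sqrt(2)), (1, 1, math.sqrt(2))]

def geodesic_dijkstra(blocked, seed):
    """8-neighbour geodesic distance from `seed` to every cell.

    blocked: 2-D boolean array, True where a cell cannot be entered.
    As in the code above, a diagonal move only checks its destination cell.
    Unreachable cells keep the value inf.
    """
    rows, cols = blocked.shape
    dist = np.full(blocked.shape, np.inf)
    dist[seed] = 0.0
    heap = [(0.0, seed)]
    while heap:
        d, (r, c) = heapq.heappop(heap)
        if d > dist[r, c]:
            continue  # stale entry: a shorter path to (r, c) was already found
        for dr, dc, step in MOVES:
            nr, nc = r + dr, c + dc
            if 0 <= nr < rows and 0 <= nc < cols and not blocked[nr, nc]:
                nd = d + step
                if nd < dist[nr, nc]:
                    dist[nr, nc] = nd
                    heapq.heappush(heap, (nd, (nr, nc)))
    return dist

blocked = np.zeros((5, 5), dtype=bool)
blocked[:, 2] = True  # a full wall: the right-hand side is unreachable from the seed
dist = geodesic_dijkstra(blocked, (0, 0))
print(np.isinf(dist[:, 3:]).all())  # → True, and the function still returns
```

Being pure Python, this will be much slower than scikit-fmm on large grids; its advantage is that termination does not depend on every cell being reachable.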

UPDATE:

I found a Cython implementation of the Fast Sweeping Method in this notebook, which can be used to achieve the same result as scikit-fmm, probably with comparable speed. One just needs to feed a binary flag matrix (with 1s at viable points and inf elsewhere) as the cost to the GDT() function.
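To illustrate the idea (not the notebook's Cython code, which is not reproduced here), the sketch below is a pure-Python, first-order Fast Sweeping solver for |∇u| = cost using the standard Godunov upwind update, with the same cost convention: 1.0 for viable cells, inf for blocked ones. The name fast_sweeping and its parameters are made up for illustration. Because it solves the eikonal equation, its distances are not octagonal, although this first-order version overestimates somewhat near the seed.

```python
import math
import numpy as np

def fast_sweeping(cost, seed, h=1.0, max_iter=50, tol=1e-9):
    """First-order fast sweeping solver for |grad u| = cost on a 2-D grid.

    cost: 2-D array, 1.0 for viable cells and inf for blocked cells.
    seed: (row, col) source cell where u = 0. Unreachable cells stay inf.
    """
    rows, cols = cost.shape
    u = np.full(cost.shape, np.inf)
    u[seed] = 0.0
    # The four alternating sweep directions of the method.
    orders = [(range(rows), range(cols)),
              (range(rows), range(cols - 1, -1, -1)),
              (range(rows - 1, -1, -1), range(cols)),
              (range(rows - 1, -1, -1), range(cols - 1, -1, -1))]
    for _ in range(max_iter):
        change = 0.0
        for row_order, col_order in orders:
            for i in row_order:
                for j in col_order:
                    f = cost[i, j]
                    if (i, j) == tuple(seed) or math.isinf(f):
                        continue  # source stays 0, blocked cells stay inf
                    # Smallest neighbour value along each axis (upwind choice).
                    a = min(u[i - 1, j] if i > 0 else math.inf,
                            u[i + 1, j] if i < rows - 1 else math.inf)
                    b = min(u[i, j - 1] if j > 0 else math.inf,
                            u[i, j + 1] if j < cols - 1 else math.inf)
                    if math.isinf(min(a, b)):
                        continue  # no information has reached this cell yet
                    fh = f * h
                    if abs(a - b) >= fh:
                        new = min(a, b) + fh  # one-sided update
                    else:
                        new = (a + b + math.sqrt(2 * fh * fh - (a - b) ** 2)) / 2
                    if new < u[i, j]:
                        change = max(change, u[i, j] - new)  # inf on a first visit
                        u[i, j] = new
        if change < tol:
            break  # a full round of sweeps changed nothing (within tol)
    return u
```

Usage mirrors the description above: build the cost array from your mask (e.g. with numpy.where(mask, numpy.inf, 1.0)) and pass the seed cell.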
