Overriding Keras Predict Function

I have a Keras model that accepts 4D inputs of shape (n, height, width, channel). However, my data generator produces 2D arrays of shape (n, width*height). So, the predict f

Solution 1:

You don't need to override predict. Instead, add a Reshape layer at the beginning of your model, so the model accepts the flat (n, width*height) arrays directly and reshapes them internally.

With the functional API:

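from keras.layers import Input, Reshape
from keras.models import Model

inp = Input((width*height,))
# restore the (height, width, channel) layout the rest of the model expects
first = Reshape((height, width, 1))(inp)

# ..... other layers .....

model = Model(inp, outputFromTheLastLayer)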
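If the 4D model already exists (for example, it is already trained), the same idea can be applied by wrapping it rather than rebuilding it. The following is a minimal sketch, assuming the `keras` package (Keras 2 or 3); `height`, `width`, and the small stand-in `image_model` are hypothetical placeholders for your own sizes and model. A non-square size is used so that a mix-up between height and width would show up as a wrong result.

```python
import numpy as np
from keras.layers import Input, Reshape, Conv2D, GlobalAveragePooling2D, Dense
from keras.models import Model

height, width = 28, 32  # hypothetical image size (deliberately non-square)

# Hypothetical stand-in for an existing model that expects (n, height, width, 1).
img_in = Input((height, width, 1))
x = Conv2D(4, 3, activation="relu")(img_in)
x = GlobalAveragePooling2D()(x)
img_out = Dense(5)(x)
image_model = Model(img_in, img_out)

# Wrapper: take flat (n, height*width) vectors, reshape them, and reuse
# image_model (its weights are shared, not copied).
flat_in = Input((height * width,))
reshaped = Reshape((height, width, 1))(flat_in)
flat_model = Model(flat_in, image_model(reshaped))

# predict now works directly on the generator's 2D arrays.
flat_batch = np.random.rand(8, height * width).astype("float32")
preds = flat_model.predict(flat_batch, verbose=0)

# Reference: reshape by hand in NumPy (row-major) and call the original model.
ref = image_model.predict(flat_batch.reshape(8, height, width, 1), verbose=0)
```

Because Reshape uses the same row-major order as NumPy's reshape, `preds` should match `ref`. This only holds if the generator flattened each image row by row, as (height, width).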

With a sequential model:

from keras.layers import Reshape
from keras.models import Sequential

model = Sequential()
model.add(Reshape((height, width, 1), input_shape=(width*height,)))
# model.add(...) the other layers follow here
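A complete, runnable version of the Sequential approach might look like the sketch below. It assumes the `keras` package; the image size and the Conv2D/Dense layers after the Reshape are hypothetical placeholders. It declares the input with an explicit `Input` layer, which newer Keras versions prefer over passing `input_shape` to the first layer.

```python
import numpy as np
from keras.layers import Input, Reshape, Conv2D, Flatten, Dense
from keras.models import Sequential

height, width = 28, 32  # hypothetical image size

model = Sequential([
    Input((height * width,)),         # flat vectors straight from the generator
    Reshape((height, width, 1)),      # -> (n, height, width, 1)
    Conv2D(8, 3, activation="relu"),  # ordinary image layers from here on
    Flatten(),
    Dense(5),
])
model.compile(optimizer="adam", loss="mse")

flat_batch = np.random.rand(4, height * width).astype("float32")
preds = model.predict(flat_batch, verbose=0)  # shape (4, 5)
```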

About the output shape.

Since your model has 5 outputs, the target you pass to fit must be a list of five arrays, one per output:

raw_train_target = [target1,target2,target3,target4,target5]
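As a minimal sketch of this, assuming the `keras` package, the following builds a hypothetical model with five single-value heads and trains it with a list of five target arrays. All sizes, layer choices, and the random data are placeholders.

```python
import numpy as np
from keras.layers import Input, Reshape, Conv2D, GlobalAveragePooling2D, Dense
from keras.models import Model

height, width, n = 28, 32, 16  # hypothetical image size and sample count

inp = Input((height * width,))
x = Reshape((height, width, 1))(inp)
x = Conv2D(4, 3, activation="relu")(x)
x = GlobalAveragePooling2D()(x)
# Five separate heads, each producing one value per sample.
outputs = [Dense(1, name=f"out{i}")(x) for i in range(5)]
model = Model(inp, outputs)
model.compile(optimizer="adam", loss=["mse"] * 5)  # one loss per output

X = np.random.rand(n, height * width).astype("float32")
# One target array per output, in the same order as `outputs`.
raw_train_target = [np.random.rand(n, 1).astype("float32") for _ in range(5)]
model.fit(X, raw_train_target, epochs=1, verbose=0)

preds = model.predict(X, verbose=0)  # a list of five (n, 1) arrays
```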

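If you cannot do that, and raw_train_target is a single array with all the targets laid out one after another, you can try using a Concatenate layer at the end so the model produces a single combined output:

# outputs is the list of your five output tensors
output = Concatenate()(outputs)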
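A minimal sketch of the Concatenate variant follows, assuming the `keras` package and assuming the single target array has shape (n, 5), with column i holding the target for output i. Concatenate joins along the last axis by default, so five (n, 1) heads become one (n, 5) output. Sizes, layers, and data are hypothetical placeholders.

```python
import numpy as np
from keras.layers import (Input, Reshape, Conv2D, GlobalAveragePooling2D,
                          Dense, Concatenate)
from keras.models import Model

height, width, n = 28, 32, 16  # hypothetical image size and sample count

inp = Input((height * width,))
x = Reshape((height, width, 1))(inp)
x = Conv2D(4, 3, activation="relu")(x)
x = GlobalAveragePooling2D()(x)
outputs = [Dense(1)(x) for _ in range(5)]  # five (n, 1) heads
output = Concatenate()(outputs)            # joined on the last axis -> (n, 5)
model = Model(inp, output)
model.compile(optimizer="adam", loss="mse")

X = np.random.rand(n, height * width).astype("float32")
# One array, one column per output (the assumed layout).
raw_train_target = np.random.rand(n, 5).astype("float32")
model.fit(X, raw_train_target, epochs=1, verbose=0)

preds = model.predict(X, verbose=0)  # a single (n, 5) array
```

If your single array stores the targets in a different layout, such as flattened into one long vector, reshape it to (n, 5) before calling fit.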