# BreastCancerMetastasis-prediction / model_4D_CNN.py
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import concat
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import regularizers
import numpy as np
import pickle
import json
import os
# Reproducibility: ask TensorFlow for deterministic kernels and pin both the
# TensorFlow and NumPy random seeds to the same value.
os.environ.update({'TF_DETERMINISTIC_OPS': '1'})
tf.random.set_seed(1)
np.random.seed(1)

# ---- Run configuration ------------------------------------------------------
# ``hybrid`` is read by MyModel.call to choose two-input (image + second
# tensor) mode. p_node, patches_approach and n_batch are not read by the
# inference code in this file (presumably training-time settings - confirm).
p_node = True
hybrid = True
patches_approach = True
n_batch = 24
# -----------------------------------------------------------------------------

#Model architecture
class MyModel(Model):
    """3D-CNN classifier with an optional second ("hybrid") input branch.

    Up to six Conv3D stages extract image features. The result is max-pooled,
    flattened and, in hybrid mode, concatenated with a second input tensor.
    Three Dropout -> BatchNorm -> Dense stages and a 2-way softmax follow.

    Args:
        InputSize: Spatial edge length recorded in the first conv layer's
            ``input_shape``. Keras only consults ``input_shape`` for the first
            layer of a Sequential model. Here the layer is built from the
            first real input, so this value is informational.
        sizeFilter: Kernel size shared by every Conv3D layer.
        dropOutBase: Dropout rate applied before each Dense layer.
        activation: Activation for the Conv3D layers and the two hidden Dense
            layers (the final Dense layer is linear).
        l2Factor: L2 regularization factor applied to every kernel.
        ConvDepth: conv1/conv2 always run. conv3, conv4, conv5 and conv51 run
            only when ConvDepth exceeds 2, 3, 4 and 5 respectively. Stages
            that are never called are never built and own no variables.
        numfilter: Sequence of at least 6 filter counts, one per conv stage.
        hybrid: True -> ``call`` expects ``[image, extra]`` inputs. False ->
            a single image tensor. None (default) -> follow the module-level
            ``hybrid`` flag at call time (the original behaviour).
    """

    def __init__(self, InputSize, sizeFilter, dropOutBase, activation, l2Factor,
                 ConvDepth, numfilter, hybrid=None):
        super(MyModel, self).__init__()
        self.depth = ConvDepth
        self.hybrid = hybrid

        def conv(filters, **extra):
            # All conv stages share kernel size, activation, L2 and 'same' padding.
            return layers.Conv3D(filters, sizeFilter, activation=activation,
                                 kernel_regularizer=regularizers.l2(l2Factor),
                                 padding='same', **extra)

        def dense(units, **extra):
            # All Dense layers share the same L2 kernel regularizer.
            return layers.Dense(units,
                                kernel_regularizer=regularizers.l2(l2Factor),
                                **extra)

        # NOTE: layer creation order determines trainable_variables order,
        # which the weight loader relies on positionally - do not reorder.
        self.conv1 = conv(numfilter[0],
                          input_shape=(InputSize, InputSize, InputSize, 3))
        self.bn1 = layers.BatchNormalization()

        self.conv2 = conv(numfilter[1])
        self.pool2 = layers.MaxPool3D(pool_size=2, strides=2)
        self.bn2 = layers.BatchNormalization()

        self.conv3 = conv(numfilter[2])
        self.pool3 = layers.MaxPool3D(pool_size=2, strides=2)
        self.bn3 = layers.BatchNormalization()

        self.conv4 = conv(numfilter[3])
        self.pool4 = layers.MaxPool3D(pool_size=2, strides=2)
        self.bn4 = layers.BatchNormalization()
        # Final pooling, applied after the last conv stage just before flattening.
        self.pool5 = layers.MaxPool3D(pool_size=3, strides=3)
        self.conv5 = conv(numfilter[4])
        self.bn5 = layers.BatchNormalization()

        self.conv51 = conv(numfilter[5])
        self.bn51 = layers.BatchNormalization()

        self.flat = layers.Flatten()
        self.drop1 = layers.Dropout(dropOutBase)
        self.bn6 = layers.BatchNormalization()
        self.d1 = dense(16, activation=activation)

        self.drop2 = layers.Dropout(dropOutBase)
        self.bn7 = layers.BatchNormalization()
        self.d2 = dense(8, activation=activation)

        self.drop3 = layers.Dropout(dropOutBase)
        self.bn8 = layers.BatchNormalization()
        self.d3 = dense(2)  # linear output; softmax is applied by self.sm
        self.sm = layers.Softmax()

    def call(self, inputs, training=False):
        """Return 2-way softmax probabilities for ``inputs``.

        ``training`` is forwarded explicitly to every Dropout and
        BatchNormalization layer. Dropout is the identity when ``training`` is
        False, so one code path serves both modes. It also works when
        ``training`` is a symbolic tensor.
        """
        use_hybrid = hybrid if self.hybrid is None else self.hybrid
        x = inputs[0] if use_hybrid else inputs

        # Convolutional feature extractor; depth-gated stages as documented above.
        x = self.bn1(self.conv1(x), training=training)
        x = self.bn2(self.pool2(self.conv2(x)), training=training)
        if self.depth > 2:
            x = self.bn3(self.pool3(self.conv3(x)), training=training)
        if self.depth > 3:
            x = self.bn4(self.pool4(self.conv4(x)), training=training)
        if self.depth > 4:
            x = self.bn5(self.conv5(x), training=training)
        if self.depth > 5:
            x = self.bn51(self.conv51(x), training=training)
        x = self.flat(self.pool5(x))

        if use_hybrid:
            # Append the second input's features along the feature axis.
            x = concat([x, inputs[1]], axis=1)

        # Classification head: Dropout -> BatchNorm -> Dense, three times.
        for drop, bn, dense_layer in ((self.drop1, self.bn6, self.d1),
                                      (self.drop2, self.bn7, self.d2),
                                      (self.drop3, self.bn8, self.d3)):
            x = dense_layer(bn(drop(x, training=training), training=training))
        return self.sm(x)


# Paths to the saved model configuration (JSON), trained weights (pickle) and input data
path_model_json = 'saved_model/activation-sigmoid_dropout-0.00_sizeFilter-5_l2factor-0.01_epochs-301616178465.json'
path_model_weight = 'saved_model/weights0.pickle'
path_MRI = 'data_MRI'
path_clinical = 'data_clinical'

# Load the two input arrays (files must be readable by np.load).
with open(path_MRI, 'rb') as handle:
    data_MRI = np.load(handle)
with open(path_clinical, 'rb') as handle:
    data_clinical = np.load(handle)

data = [tf.convert_to_tensor(data_MRI), tf.convert_to_tensor(data_clinical)]

# Load the saved hyper-parameters and rebuild the model with them.
with open(path_model_json, encoding='utf-8') as handle:
    model_info = json.load(handle)
model = MyModel(100, model_info['sizeFilter'], model_info['dropout'],
                model_info['activation'], model_info['l2factor'],
                model_info['depth'], model_info['NumFilters'])

# Subclassed Keras models create their variables on the first call; variable
# shapes do not depend on the batch size, so a single sample is enough.
model([tensor[:1] for tensor in data])

# NOTE(review): pickle.load can execute arbitrary code - only load weight
# files from a trusted source.
with open(path_model_weight, 'rb') as handle:
    tem_weight = pickle.load(handle)

# Weights are restored positionally: the pickle must hold one array per
# trainable variable, in model.trainable_variables order. Fail loudly on a
# count mismatch instead of silently leaving variables unassigned/ignored.
trainable = model.trainable_variables
if len(tem_weight) != len(trainable):
    raise ValueError(
        'weight file {!r} holds {} arrays but the model has {} trainable '
        'variables'.format(path_model_weight, len(tem_weight), len(trainable)))
for variable, value in zip(trainable, tem_weight):
    variable.assign(value)
# NOTE(review): only trainable variables are restored. BatchNormalization
# moving mean/variance (non-trainable) keep their initial values, and those
# are what inference (training=False) uses - confirm this matches how the
# weights were saved.

# Run inference on the full dataset.
predictions = model(data)