# Source code for pipeline.pipeline

from monai.networks.nets import SwinUNETR, UNETR
from monai.data import load_decathlon_datalist, CacheDataset, decollate_batch, DataLoader
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureTyped,
    Activationsd,
    Invertd,
    AsDiscreted,
    SaveImaged,
    KeepLargestConnectedComponentd,
)
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
import torch
import re
import nibabel as nib
import os
from tqdm import tqdm
import json
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import math
from monai import data


class Pipeline:
    """
    Class for managing machine learning pipeline for medical image semantic segmentation.

    It assists with loading models and data for training, and it automatically
    records metrics and saves check points.

    Attributes:
        debug_mode (bool): Whether the pipeline is in debug mode or not
        device: torch device the model runs on ("cuda" when available, else "cpu")
        model: The neural network used for training or inference.
        model_type (str): Neural network architecture of model. Currently supports UNETR and SWINUNETR
        train_transforms: Transformations applied on the training dataset
        val_transforms: Transformations applied on the validation dataset
        modality (int): Input dimension of the loaded dataset
        num_of_labels (int): Number of output classes of the dataset
        dataset_name (str): Name of the dataset
        num_train_images (int): Number of training images in the dataset
        num_val_images (int): Number of validation images in the dataset
        train_batch_size (int): Batchsize for training
    """

    def __init__(self, model_type: str, modality: int, num_of_labels: int, model_path: str = "",
                 debug: bool = False):
        """
        Select the compute device and build the model.

        Args:
            model_type (str): Architecture to build, "UNETR" or "SWINUNETR".
            modality (int): Number of input channels of the model.
            num_of_labels (int): Number of output classes of the model.
            model_path (str): The path to the pretrained model as a string. Should include the
                model .pth file. An empty string means no checkpoint is loaded.
            debug (bool): Boolean that enables debug messages. Defaults to false to disable messages.
        """
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.debug_mode = debug
        self.create_model(model_type=model_type, modality=modality, num_of_labels=num_of_labels,
                          model_path=model_path)

    def create_model(self, model_type: str, modality: int, num_of_labels: int, model_path: str = "") -> None:
        """
        Creates a new model for the pipeline

        Args:
            model_type (str): Type of model the pipeline uses, takes value "UNETR" or "SWINUNETR"
                for their respective model types.
            modality (int): Modality of the dataset. Determines the input dimension of the model.
            num_of_labels (int): Number of labels to the dataset.
            model_path (str): File path to the saved model of the same type as model_type.

        Raises:
            ValueError: If model_type is neither "UNETR" nor "SWINUNETR".
        """
        self.model_type = model_type
        if model_type == "UNETR":
            self.model = UNETR(
                in_channels=modality,
                out_channels=num_of_labels,
                img_size=(96, 96, 96),
                feature_size=16,
                hidden_size=768,
                mlp_dim=3072,
                num_heads=12,
                pos_embed="perceptron",
                norm_name="instance",
                res_block=True,
                dropout_rate=0.0,
            ).to(self.device)
        elif model_type == "SWINUNETR":
            self.model = SwinUNETR(
                img_size=(96, 96, 96),
                in_channels=modality,
                out_channels=num_of_labels,
                feature_size=48,
                use_checkpoint=True,
            ).to(self.device)
            if model_path == "":
                # Best effort: start from the self-supervised weights shipped next to this
                # file. Failing to load them must not abort model creation, but we no longer
                # swallow KeyboardInterrupt / SystemExit as the bare `except:` did.
                try:
                    pretrained_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                                   "model_swinvit.pt")
                    self.model.load_from(torch.load(pretrained_path, map_location=self.device))
                except Exception as err:
                    self.__debug("Warning: Could not find model_swinvit.pt. It is best to initiate SwinUNETR " +
                                 "with self supervised pretrained model to reduce training time")
                    self.__debug(f"Reason: {err}")
        else:
            # ValueError is a subclass of Exception, so existing `except Exception` callers still work.
            raise ValueError("Unexpected model type given")

        if model_path != "":
            self.load_model(model_path)

    def load_model(self, model_path: str) -> None:
        """
        Load the saved model.

        Args:
            model_path (str): File path to the saved model of the same type as model_type.
        """
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))

    def load_data(self, dataset_path: str, train_transforms, val_transforms, cache_num_train: int = 24,
                  train_batch_size: int = 1, cache_num_val: int = 6, val_batch_size: int = 1,
                  workers: int = 4) -> None:
        """
        Load the training and validation data from the json file for the dataset.

        Args:
            dataset_path (str): File path to the json file of the dataset.
            train_transforms: Transformation done on the dataset during training.
            val_transforms: Transformation done on the dataset during validation.
            cache_num_train (int): Number of cached data for training dataset.
            train_batch_size (int): Batch size for training.
            cache_num_val (int): Number of cached data for validation dataset.
            val_batch_size (int): Batch size for validation.
            workers (int): Number of workers working in parallel.
        """
        datalist = load_decathlon_datalist(dataset_path, True, "training")
        val_files = load_decathlon_datalist(dataset_path, True, "validation")

        train_ds = CacheDataset(
            data=datalist,
            transform=train_transforms,
            cache_num=cache_num_train,
            cache_rate=1.0,
            num_workers=workers,
        )
        train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True,
                                  num_workers=workers, pin_memory=True)
        val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=cache_num_val,
                              cache_rate=1.0, num_workers=workers)
        val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False,
                                num_workers=workers, pin_memory=True)

        self.train_transforms = train_transforms
        self.val_transforms = val_transforms

        # Read the dataset description with a context manager so the handle is always closed.
        with open(dataset_path, encoding="utf-8") as dataset_file:
            json_data = json.load(dataset_file)

        self.modality = len(json_data['modality'])
        self.num_of_labels = len(json_data['labels'])
        self.dataset_name = json_data['name']
        self.num_train_images = len(train_ds)
        self.num_val_images = len(val_ds)
        self.train_batch_size = train_batch_size
        self.val_loader = val_loader
        self.train_loader = train_loader

    def train(self, max_epoch: int, epoch_val: int, learning_rate: float = 1e-4,
              weight_decay: float = 1e-5) -> None:
        """
        Initiate training for the loaded model on the loaded dataset.

        A time-stamped folder is created in the current working directory; the best
        checkpoints and the tensorboard logs are written into it. `load_data` must have
        been called first.

        Args:
            max_epoch: Total number of epoch to train.
            epoch_val: Number of epochs between every validation and saving the model
            learning_rate: learning rate of the training process with AdamW optimizer
            weight_decay: Weight decay for the AdamW optimizer

        Raises:
            ValueError: If epoch_val does not yield a positive validation interval.
        """
        torch.backends.cudnn.benchmark = True
        # Mixed precision only makes sense on CUDA; disabling it explicitly keeps the CPU
        # fallback chosen in __init__ working.
        use_amp = self.device.type == "cuda"

        loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        max_iterations = math.ceil(max_epoch * self.num_train_images / self.train_batch_size)
        eval_num = math.ceil(epoch_val * self.num_train_images / self.train_batch_size)
        if eval_num < 1:
            raise ValueError("epoch_val must give a validation interval of at least one step")

        post_label = AsDiscrete(to_onehot=self.num_of_labels)
        post_pred = AsDiscrete(argmax=True, to_onehot=self.num_of_labels)
        dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

        run_name = self.model_type + self.dataset_name
        save_folder = run_name + str(datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
        log_folder = os.path.join(save_folder, "logs")
        os.makedirs(log_folder)
        writer = SummaryWriter(os.path.join(log_folder, datetime.now().strftime("%Y_%m_%d")))

        # Number of optimizer steps completed so far.
        global_step = 0
        dice_val_best = 0.0

        def run_validation() -> float:
            """Return the mean Dice score of the current model over the validation loader."""
            self.model.eval()
            val_iterator = tqdm(self.val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            with torch.no_grad():
                for batch in val_iterator:
                    val_inputs = batch["image"].to(self.device)
                    val_labels = batch["label"].to(self.device)
                    with torch.cuda.amp.autocast(enabled=use_amp):
                        val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, self.model)
                    labels_onehot = [post_label(label) for label in decollate_batch(val_labels)]
                    preds_onehot = [post_pred(pred) for pred in decollate_batch(val_outputs)]
                    # Calculate metrics
                    dice_metric(y_pred=preds_onehot, y=labels_onehot)
                    val_iterator.set_description("Validate (%d / %d Steps)" % (global_step, max_iterations))
            mean_dice_val = dice_metric.aggregate().item()
            dice_metric.reset()
            return mean_dice_val

        try:
            while global_step < max_iterations:
                self.model.train()
                epoch_loss = 0.0
                epoch_iterator = tqdm(self.train_loader, desc="Training (X / X Steps) (loss=X.X)",
                                      dynamic_ncols=True)
                for step, batch in enumerate(epoch_iterator, start=1):
                    x = batch["image"].to(self.device)
                    y = batch["label"].to(self.device)
                    with torch.cuda.amp.autocast(enabled=use_amp):
                        logit_map = self.model(x)
                        loss = loss_function(logit_map, y)
                    scaler.scale(loss).backward()
                    loss_value = loss.item()
                    epoch_loss += loss_value
                    scaler.unscale_(optimizer)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                    # Count the step BEFORE the validation check. Checking first meant the
                    # last step of a run never triggered validation when the step count was
                    # an exact multiple of eval_num, so the final model was never saved.
                    global_step += 1
                    epoch_iterator.set_description(
                        f"Training ({global_step} / {max_iterations} Steps) (loss={loss_value:2.5f})"
                    )

                    if global_step % eval_num == 0 or global_step == max_iterations:
                        dice_val = run_validation()
                        # Mean loss of this epoch so far; the running sum is left untouched.
                        mean_loss = epoch_loss / step
                        if dice_val > dice_val_best:
                            dice_val_best = dice_val
                            torch.save(self.model.state_dict(),
                                       os.path.join(save_folder, run_name + str(global_step) + ".pth"))
                            print(
                                "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                                    dice_val_best, dice_val)
                            )
                        # Record metric with tensorboard
                        writer.add_scalar("Dice Val", dice_val, global_step=global_step)
                        writer.add_scalar("Dice Cross Entropy Loss", mean_loss, global_step=global_step)
                        # Validation switched the model to eval mode; resume training mode
                        # for the remaining steps of this epoch.
                        self.model.train()
        finally:
            # Flush and release the tensorboard writer even if training is interrupted.
            writer.close()

    def inference(self, data_folder, output_folder, transforms) -> None:
        """
        Runs the prediction on the files located under data_folder and saves the results
        as Nifti (.nii.gz) files under output_folder.

        Args:
            data_folder (str): The folder where the data is located as string. All files in this
                folder should be medical images.
            output_folder: The folder path to save the nifti images as a string. If None, the
                results are saved to data_folder.
            transforms: Transformations to apply onto images before inference. Should be similar to
                transformation done on validation dataset
        """
        if output_folder is None:
            output_folder = data_folder

        self.inference_transforms = transforms
        self.file_dicts = []
        self.files = []
        self.__load_inference_dataset(data_folder)

        # Post-processing transforms, built once and reused for every image.
        # Source: https://github.com/MASILab/3DUX-Net/tree/14ea46b7b4c4980b46aba066aaaa24b1d9c1bb0d
        post_transforms = Compose([
            EnsureTyped(keys="pred"),
            Activationsd(keys="pred", softmax=True),
            Invertd(
                keys="pred",  # invert the `pred` data field
                transform=self.inference_transforms,
                orig_keys="image",  # use the pre-transform information recorded for `image`
                meta_keys="pred_meta_dict",  # key field to save inverted meta data
                orig_meta_keys="image_meta_dict",  # meta data (e.g. `affine`) used when inverting
                meta_key_postfix="meta_dict",  # only used when the two meta keys above are None
                nearest_interp=False,  # keep a smooth output; `AsDiscreted` runs afterwards
                to_tensor=True,  # convert to PyTorch Tensor after inverting
            ),
            AsDiscreted(keys="pred", argmax=True),
            KeepLargestConnectedComponentd(keys='pred', applied_labels=[1, 3]),
            SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir=output_folder,
                       output_postfix="temp", output_ext=".nii.gz", resample=True, separate_folder=False),
        ])

        self.model.eval()
        with torch.no_grad():
            # The loader uses batch_size=1 and shuffle=False, so `index` lines up with self.files.
            for index, test_data in enumerate(self.val_loader_inference):
                # Make prediction
                img = test_data["image"].to(self.device)
                test_data["pred"] = sliding_window_inference(img, (96, 96, 96), 4, self.model, overlap=0.8)

                for item in decollate_batch(test_data):
                    post_transforms(item)

                # Small modification to affine matrix
                self.__load_and_translate(output_folder=output_folder, file_name=self.files[index])

    def __load_inference_dataset(self, data_folder: str) -> None:
        """
        Loads the data specified under data_folder into a Monai DataLoader that applies
        self.inference_transforms.

        Args:
            data_folder: Path as a string to the folder where the medical images to be segmented are located.
        """
        self.__load_files_from_folder(data_folder)
        test_dataset = data.Dataset(data=self.file_dicts, transform=self.inference_transforms)
        self.val_loader_inference = data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )

    def __load_files_from_folder(self, data_folder: str) -> None:
        """
        Loads the files into a list of dictionaries to be read by Monai's built in dataset.

        Every file found under data_folder (recursively) becomes one {"image": path} entry,
        the key the transforms expect. This is a mock of Monai's load_decathlon_datalist().
        Directories and files are visited in sorted order so the result is deterministic.

        Args:
            data_folder: Path as a string to the folder where the medical images to be segmented are located.
        """
        self.file_dicts.clear()
        self.files.clear()
        for root, dirs, files in os.walk(data_folder):
            dirs.sort()  # in-place sort controls the order os.walk descends into sub-folders
            for file in sorted(files):
                self.files.append(file)
                self.file_dicts.append({"image": os.path.join(root, file)})

    def __debug(self, message: str) -> None:
        """
        Debug print statements, allows debug messages to be sent if self.debug_mode is True.

        Args:
            message: The message that is sent
        """
        if self.debug_mode:
            print(message)

    def __load_and_translate(self, output_folder, file_name) -> None:
        """
        Helper function that loads the temporary file saved by monai, applies the affine
        matrix modifications to it, saves the result as "<name>-segmented.nii.gz" and then
        deletes the temporary file.

        Args:
            output_folder: The path to the folder that contains the temporary monai saved file.
            file_name: The name of the file that was being analyzed.
        """
        temp_name = re.sub(r"\.nii\.gz$", "_temp.nii.gz", file_name)
        temp_file_path = os.path.join(output_folder, temp_name)
        seg_img = nib.load(temp_file_path)
        self.__debug(f"segm affine is {seg_img.affine}")
        self.__debug(f"segm shape is {seg_img.shape}")

        # Work on a copy so the loaded image's own affine is not modified in place.
        new_affine = seg_img.affine.copy()
        new_affine[:3, 3] = [0, 0, 0]  # drop the translation
        new_affine[1, 1] = -1 * new_affine[1, 1]  # flip the sign of the second axis scale
        self.__debug(f"New affine is {new_affine}")

        new_name = re.sub(r"\.nii\.gz$", "-segmented.nii.gz", file_name)
        nib.save(
            nib.Nifti1Image(seg_img.get_fdata(), affine=new_affine),
            os.path.join(output_folder, new_name)
        )

        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            self.__debug(f"File '{temp_file_path}' deleted successfully")
        else:
            self.__debug(f"File '{temp_file_path}' does not exist")
if __name__ == "__main__":
    # Debug entry point: builds the BTCV transforms, then runs inference and a short
    # training session against hard-coded local paths.
    # Only run this file directly for debugging.
    from monai.transforms import (
        AsDiscrete,
        EnsureChannelFirstd,
        Compose,
        CropForegroundd,
        LoadImaged,
        Orientationd,
        RandFlipd,
        RandCropByPosNegLabeld,
        RandShiftIntensityd,
        ScaleIntensityRanged,
        Spacingd,
        RandRotate90d,
        ResizeWithPadOrCropd,
    )

    # Raw strings: the previous literals mixed "\\" with bare "\B", "\A", "\R", "\d" and
    # "\i", which are invalid escape sequences (SyntaxWarning on Python 3.12+). The
    # resulting path values are unchanged.
    MODEL_PATH = r"F:\2404_Organ_Segmentation\segmentation-pipeline\best_metric_model_3dUNETR54375.pth"
    DATASET_JSON = r"F:\2404_Organ_Segmentation\BTCV\Abdomen\RawData\dataset_0.json"
    INFERENCE_INPUT = r"F:\2404_Organ_Segmentation\BTCV\Abdomen\RawData\inf_test"
    INFERENCE_OUTPUT = r"F:\2404_Organ_Segmentation\BTCV\Abdomen\RawData\inf_output"

    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-175,
                a_max=250,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                # This here needs to be negative
                spatial_size=(96, 96, -1),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            ResizeWithPadOrCropd(
                keys=["image", "label"],
                spatial_size=(96, 96, 96),
                mode='constant',
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[0],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[1],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[2],
                prob=0.10,
            ),
            RandRotate90d(
                keys=["image", "label"],
                prob=0.10,
                max_k=3,
            ),
            RandShiftIntensityd(
                keys=["image"],
                offsets=0.10,
                prob=0.50,
            ),
        ]
    )

    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["image", "label"], source_key="image"),
        ]
    )

    inf_transforms = Compose(
        [
            LoadImaged(keys=["image"]),
            EnsureChannelFirstd(keys=["image"]),
            Orientationd(keys=["image"], axcodes="RAS"),
            Spacingd(
                keys=["image"],
                pixdim=(1.5, 1.5, 2.0),
                mode="bilinear",
            ),
            ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["image"], source_key="image"),
        ]
    )

    trainer = Pipeline(model_type="UNETR", modality=1, num_of_labels=14, model_path=MODEL_PATH, debug=True)
    trainer.load_data(DATASET_JSON, train_transforms, val_transforms)
    trainer.inference(data_folder=INFERENCE_INPUT, output_folder=INFERENCE_OUTPUT, transforms=inf_transforms)
    trainer.train(2, 1)