Commit

Fixing Moshiko's PR review comments
Itai Guez [email protected] committed May 24, 2023
1 parent 4ec336a commit fea5332
Showing 3 changed files with 10 additions and 12 deletions.
@@ -24,12 +24,12 @@ train:
   learning_rate : 1e-3
   weight_decay : 0
   run_sample : 0  # if 0, use all samples; otherwise sample only run_sample samples (for test purposes)
-  resume_checkpoint_filename :
+  resume_checkpoint_filename : null
   trainer:
     accelerator : gpu
     devices : 1
     num_epochs : 100
-    ckpt_path :
+    ckpt_path : null
   unet_kwargs :
     strides : [[2, 2, 2], [1, 2, 2], [1, 2, 2], [1, 2, 2], [2, 2, 2]]
     channels : [32, 64, 128, 256, 512, 1024]
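A note on the change above: in YAML, a key with an empty value already loads as None, so the explicit `null` documents the default rather than changing behavior. A minimal, runnable sketch (assumes PyYAML is installed):

```python
# An empty YAML value and an explicit `null` both load as Python None.
import yaml

cfg = yaml.safe_load(
    """
train:
  resume_checkpoint_filename :
  trainer:
    ckpt_path : null
"""
)
assert cfg["train"]["resume_checkpoint_filename"] is None
assert cfg["train"]["trainer"]["ckpt_path"] is None
```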
fuse_examples/imaging/segmentation/picai/runner.py (7 changes: 3 additions & 4 deletions)
@@ -144,7 +144,7 @@ def run_train(paths: NDict, train: NDict) -> None:
     lgr.info("- Create sampler:")
     sampler = BatchSamplerDefault(
         dataset=train_dataset,
-        balanced_class_name="data.gt.classification",  # gt_label, TODO - make diff label for balance-sampler
+        balanced_class_name="data.gt.classification",
         num_balanced_classes=2,
         batch_size=train["batch_size"],
         mode="approx",
@@ -199,7 +199,7 @@ def run_train(paths: NDict, train: NDict) -> None:

     # either a dict with arguments to pass to ModelCheckpoint, or a list of dicts for multiple ModelCheckpoint callbacks (to monitor and save checkpoints for more than one metric)
     best_epoch_source = dict(
-        monitor="validation.losses.total_loss",  # metrics.auc.macro_avg",
+        monitor="validation.losses.total_loss",
         mode="min",
     )
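As the comment above notes, best_epoch_source holds arguments for PyTorch Lightning's ModelCheckpoint; a rough equivalent when constructed directly (the dirpath value is illustrative):

```python
# Sketch: roughly the callback FuseMedML builds from best_epoch_source.
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath=paths["model_dir"],              # illustrative output location
    monitor="validation.losses.total_loss",  # loss logged during validation
    mode="min",                              # lower total loss is better
)
```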

@@ -331,7 +331,6 @@ def run_eval(paths: NDict, infer: NDict) -> None:
     # define iterator
 
     def data_iter() -> NDict:
-        # set seed
         data_file = os.path.join(paths["inference_dir"], "infer.pickle")
         data = pd.read_pickle(data_file)
         for fold in data:
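The iterator body is truncated in this diff; a hedged sketch of how such a generator might complete (the per-fold DataFrame structure and the yielded record shape are assumptions, not taken from the repo):

```python
# Sketch only: yield one NDict per inference record.
import os
import pandas as pd
from fuse.utils.ndict import NDict  # assumed import path

def data_iter() -> NDict:
    # `paths` comes from the enclosing run_eval scope in the original code
    data_file = os.path.join(paths["inference_dir"], "infer.pickle")
    data = pd.read_pickle(data_file)
    for fold in data:
        for _, row in fold.iterrows():  # assumption: each fold is a DataFrame
            yield NDict(row.to_dict())
```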
@@ -372,7 +371,7 @@ def main(cfg: DictConfig) -> None:
if "infer" in cfg["run.running_modes"]:
run_infer(NDict(cfg["infer"]), NDict(cfg["paths"]), NDict(cfg["train"]))

# analyze - skipping as it crushes without metrics
# analyze
if "eval" in cfg["run.running_modes"]:
run_eval(NDict(cfg["paths"]), NDict(cfg["infer"]))

fuseimg/datasets/picai.py (11 changes: 5 additions & 6 deletions)
@@ -298,13 +298,12 @@ def dataset(
 ) -> DatasetDefault:
     """
     Creates a single Fuse Dataset object (for training, validation, test, or a user-defined set)
-    :param data_dir: dataset root path
-    :param clinical_file path to clinical_file
-    :param target target name used from the ground truth dataframe
-    :param cache_dir: Optional, name of the cache folder
+    :param paths: paths dictionary for dataset files
+    :param cfg: dict cfg for the training phase
     :param reset_cache: Optional, specifies if we want to clear the cache first
-    :param sample_ids: dataset including the specified sample_ids or None for all the samples. sample_id is case_{id:05d} (for example case_00001 or case_00100).
-    :param train: True if used for training - adds augmentation operations to the pipeline
+    :param sample_ids: dataset including the specified sample_ids, or None for all the samples.
+    :param train: True if used for training - adds augmentation operations to the pipeline
     :param run_sample: if > 0, sample only run_sample examples from the full set (used for testing); if 0, take all samples
     :return: DatasetDefault object
     """

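Given the revised docstring, a hedged usage sketch of dataset() with the documented parameters (all argument values are illustrative, not from the repo):

```python
# Sketch only: calling dataset() as documented above.
train_dataset = dataset(
    paths=paths,        # paths dictionary for dataset files
    cfg=cfg["train"],   # training-phase config
    reset_cache=False,  # keep any existing cache
    sample_ids=None,    # None -> use all samples
    train=True,         # adds augmentation ops to the pipeline
    run_sample=0,       # 0 -> take all samples
)
```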
