Skip to content

Commit 8f876db

Browse files
mjanuszcopybara-github
authored andcommitted
Minor API and typo cleanups.
PiperOrigin-RevId: 874530870
1 parent 3dbe9e4 commit 8f876db

6 files changed

Lines changed: 47 additions & 38 deletions

File tree

ffn/input/volume.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def _update_config_for_augmentation(
153153
)
154154

155155
if vol_cfg.filter_shape != (1, 1, 1) and vol_cfg.filter_shape is not None:
156-
vol_cfg.filer_shape = tuple(
156+
vol_cfg.filter_shape = tuple(
157157
augmentation.input_size_for_rotated_output(
158158
vol_cfg.filter_shape, voxel_size
159159
)
@@ -174,7 +174,7 @@ def _postprocess_augmented_data(
174174
shape = vol_cfg.load_shape
175175

176176
def _update_array(x, name=name, shape=shape):
177-
setattr(x, name, mask.crop(x[name], (0, 0, 0), shape))
177+
x[name] = mask.crop(x[name], (0, 0, 0), shape)
178178
return x
179179

180180
ds = ds.map(

ffn/training/augmentation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def random_contrast_brightness_adjustment(
369369
Returns:
370370
tf.Tensor: Adjusted Tensor.
371371
"""
372-
adjust_tensor = tf.identity(input_tensor)
372+
adjust_tensor = input_tensor
373373
if contrast_factor_range:
374374
min_contrast, max_contrast = contrast_factor_range
375375
contrast_factor = tf.random.uniform([], min_contrast, max_contrast)
@@ -383,7 +383,7 @@ def random_contrast_brightness_adjustment(
383383
if apply_adjustment_to == 'foreground':
384384
adjust_tensor = tf.where(seg_tensor > 0, adjust_tensor, input_tensor)
385385
elif apply_adjustment_to == 'background':
386-
adjust_tensor = tf.where(seg_tensor <= 0, input_tensor, adjust_tensor)
386+
adjust_tensor = tf.where(seg_tensor <= 0, adjust_tensor, input_tensor)
387387
return adjust_tensor
388388

389389

@@ -451,15 +451,15 @@ def __init__(
451451

452452
if self.reflectable_axes.size > 0:
453453
self.reflect_decisions = (
454-
tf.random_uniform([len(self.reflectable_axes)], seed=reflection_seed)
454+
tf.random.uniform([len(self.reflectable_axes)], seed=reflection_seed)
455455
> 0.5
456456
)
457457
self.reflected_axes = tf.boolean_mask(
458458
self.reflectable_axes, self.reflect_decisions
459459
)
460460

461461
if self.permutable_axes.size > 0:
462-
self.permutation = tf.random_shuffle(
462+
self.permutation = tf.random.shuffle(
463463
self.permutable_axes, seed=permutation_seed
464464
)
465465
# full_permutation must be a list rather than an np.array of int32 because

ffn/training/inputs.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# ==============================================================================
1515
"""Tensorflow Python ops and utilities for generating network inputs."""
1616

17+
import functools
1718
import random
1819
import re
1920
from typing import Any, Callable, Optional, Sequence
@@ -553,7 +554,7 @@ def soften_labels(bool_labels, softness=0.05, scope='soften_labels'):
553554
Tensor with same shape as bool_labels with dtype `float32` and values 0.05
554555
for False and 0.95 for True.
555556
"""
556-
with tf.op_scope([bool_labels, softness], scope):
557+
with tf.name_scope(scope):
557558
label_shape = tf.shape(bool_labels, name='label_shape')
558559
return tf.where(bool_labels,
559560
tf.fill(label_shape, 1.0 - softness, name='soft_true'),
@@ -703,6 +704,35 @@ def sample(
703704
return sampled_dataset
704705

705706

707+
@functools.lru_cache(maxsize=None)
708+
def _parse_bounding_boxes(
709+
volinfo_map_string: str, use_bboxes: bool = True
710+
) -> dict[bytes, list[bounding_box.BoundingBox]]:
711+
boxes_by_volname = {}
712+
for mapping in volinfo_map_string.split(','):
713+
k, volinfo_path = mapping.split(':')
714+
k = k.encode('utf-8')
715+
assert k not in boxes_by_volname
716+
717+
if volinfo_path.endswith('metadata.json'):
718+
f = open(volinfo_path, 'r')
719+
meta = metadata.VolumeMetadata.from_json(f.read())
720+
if use_bboxes:
721+
bboxes = meta.bounding_boxes
722+
else:
723+
bboxes = [
724+
bounding_box.BoundingBox(
725+
(0, 0, 0),
726+
(meta.volume_size.x, meta.volume_size.y, meta.volume_size.z),
727+
)
728+
]
729+
boxes_by_volname[k] = bboxes
730+
731+
if not boxes_by_volname:
732+
raise ValueError('boxes_by_volname is empty.')
733+
return boxes_by_volname
734+
735+
706736
def coordinates_in_bounds(
707737
coordinates: tf.Tensor,
708738
volname: tf.Tensor,
@@ -734,28 +764,7 @@ def coordinates_in_bounds(
734764
coordinates or an empty constant of shape `[0, 3]`, which can then be
735765
passed to batching (e.g. see tests).
736766
"""
737-
boxes_by_volname = {}
738-
for mapping in volinfo_map_string.split(','):
739-
k, volinfo_path = mapping.split(':')
740-
k = k.encode('utf-8')
741-
assert k not in boxes_by_volname
742-
743-
if volinfo_path.endswith('metadata.json'):
744-
f = open(volinfo_path, 'r')
745-
meta = metadata.VolumeMetadata.from_json(f.read())
746-
if use_bboxes:
747-
bboxes = meta.bounding_boxes
748-
else:
749-
bboxes = [
750-
bounding_box.BoundingBox(
751-
(0, 0, 0),
752-
(meta.volume_size.x, meta.volume_size.y, meta.volume_size.z),
753-
)
754-
]
755-
boxes_by_volname[k] = bboxes
756-
757-
if not boxes_by_volname:
758-
raise ValueError('boxes_by_volname is empty.')
767+
boxes_by_volname = _parse_bounding_boxes(volinfo_map_string, use_bboxes)
759768

760769
def _in_bounds_fn(coordinates, volname):
761770
boxes = boxes_by_volname[volname[0]]
@@ -767,8 +776,8 @@ def _in_bounds_fn(coordinates, volname):
767776
return False
768777

769778
with tf.name_scope(name, values=[coordinates, volname]) as scope:
770-
assert coordinates.shape_as_list() == [1, 3]
771-
assert volname.shape_as_list() == [1]
779+
assert coordinates.shape.as_list() == [1, 3]
780+
assert volname.shape.as_list() == [1]
772781
in_bounds = tf.py_func(
773782
_in_bounds_fn,
774783
[coordinates, volname],

ffn/training/mask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
"""Utilites for dealing with 2d and 3d object masks."""
16+
"""Utilities for dealing with 2d and 3d object masks."""
1717

1818
from typing import Optional, Sequence
1919

ffn/training/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,13 @@ def set_up_sigmoid_pixelwise_loss(self, logits):
122122
pixel_loss *= self.loss_weights
123123
self.loss = tf.reduce_mean(pixel_loss)
124124
tf.summary.scalar('pixel_loss', self.loss)
125-
self.loss = tf.verify_tensor_all_finite(self.loss, 'Invalid loss detected')
125+
self.loss = tf.debugging.check_numerics(self.loss, 'Invalid loss detected')
126126

127127
def set_up_optimizer(self, loss=None, max_gradient_entry_mag=0.7):
128128
"""Sets up the training op for the model."""
129129
if loss is None:
130130
loss = self.loss
131-
tf.summary.scalar('optimizer_loss', self.loss)
131+
tf.summary.scalar('optimizer_loss', loss)
132132

133133
opt = optimizer.optimizer_from_flags()
134134
self.opt = opt

ffn/training/tracker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def get_summaries(self) -> list[tf.Summary.Value]:
361361
for i, summary in enumerate(images):
362362
summary.tag += '/%d' % i
363363

364-
total_moves = sum(self.moves.tf_value)
364+
total_moves = max(sum(self.moves.tf_value), 1)
365365
move_summaries = []
366366
for mt in MoveType:
367367
move_summaries.append(
@@ -377,14 +377,14 @@ def get_summaries(self) -> list[tf.Summary.Value]:
377377
tag='fov/masked_voxel_fraction',
378378
simple_value=(
379379
self.fov_stats.tf_value[FovStat.MASKED_VOXELS]
380-
/ self.fov_stats.tf_value[FovStat.TOTAL_VOXELS]
380+
/ max(self.fov_stats.tf_value[FovStat.TOTAL_VOXELS], 1)
381381
),
382382
),
383383
tf.Summary.Value(
384384
tag='fov/average_weight',
385385
simple_value=(
386386
self.fov_stats.tf_value[FovStat.WEIGHTS_SUM]
387-
/ self.fov_stats.tf_value[FovStat.TOTAL_VOXELS]
387+
/ max(self.fov_stats.tf_value[FovStat.TOTAL_VOXELS], 1)
388388
),
389389
),
390390
tf.Summary.Value(
@@ -418,7 +418,7 @@ def get_summaries(self) -> list[tf.Summary.Value]:
418418
)
419419

420420
for r, r_moves in self.moves_by_r.items():
421-
total_moves = sum(r_moves.tf_value)
421+
total_moves = max(sum(r_moves.tf_value), 1)
422422
summaries.extend([
423423
tf.Summary.Value(
424424
tag='moves/r=%d/correct' % r,

0 commit comments

Comments
 (0)