We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9e30ab6 commit d1c5a6eCopy full SHA for d1c5a6e
1 file changed
tensorflow/python/training/checkpoint_utils.py
@@ -443,7 +443,8 @@ def _set_checkpoint_initializer(variable,
443
is_partitioned_ev = variable._save_slice_info is not None
444
partition_id = variable._save_slice_info.var_offset[0] if is_partitioned_ev else 0
445
partition_num = variable._save_slice_info.full_shape[0] if is_partitioned_ev else 1
446
- with ops.control_dependencies([variable._initializer_op]):
+ restore_dependency = ops.get_collection(ops.GraphKeys.EMBEDDING_VARIABLE_RESTORE_DEPENDENCY)[0]
447
+ with ops.control_dependencies(restore_dependency[variable._primary_handle]):
448
rank = variable.initial_value.get_shape().rank - 1
449
restore_op = gen_kv_variable_ops.kv_resource_import_v3(
450
ckpt_file,
0 commit comments