@himaprasoon
Created July 11, 2019 11:23

Gets the names of all placeholders required to run a tensor or op.
```python
# Source: https://stackoverflow.com/a/47802308/3534616
def get_tensor_dependencies(tensor):
    """Return the names of the Placeholder ops that `tensor` (a tensor or an op) depends on."""
    # If a tensor is passed in, get its op; an op has no `.op` attribute,
    # so fall back to treating the argument itself as the op
    try:
        tensor_op = tensor.op
    except AttributeError:
        tensor_op = tensor
    # Recursively analyze inputs, keeping each placeholder name only once
    dependencies = []
    for inp in tensor_op.inputs:
        new_d = get_tensor_dependencies(inp)
        non_repeated = [d for d in new_d if d not in dependencies]
        dependencies = [*dependencies, *non_repeated]
    # If we've reached the "end" (an op with no inputs), return its name
    # only if it is a placeholder
    if len(tensor_op.inputs) == 0:
        dependencies = [tensor_op.name] if tensor_op.type == 'Placeholder' else []
    # Return a list of placeholder op names
    return dependencies
```
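The recursive version above re-walks every sub-graph that can be reached through more than one path, because it only de-duplicates the results, not the visits, and a very deep graph can exceed Python's recursion limit. The following is a minimal sketch of the same traversal as an iterative depth-first search with a visited set. It does not import TensorFlow: `FakeOp`, `FakeTensor` and `get_placeholder_names` are hypothetical names, and the two classes model only the attributes the gist relies on (`.op`, `.inputs`, `.type`, `.name`), on the assumption that TF1-style graph tensors and ops expose the same attributes.

```python
from dataclasses import dataclass, field
from typing import List, Set, Union


# Hypothetical stand-ins for TF1 graph objects (assumption: only .op, .inputs,
# .type and .name matter). eq=False keeps identity-based hashing, so ops can
# be stored in a set the way real graph objects can.
@dataclass(eq=False)
class FakeOp:
    name: str
    type: str
    inputs: List["FakeTensor"] = field(default_factory=list)


@dataclass(eq=False)
class FakeTensor:
    op: FakeOp


def get_placeholder_names(tensor_or_op: Union[FakeTensor, FakeOp]) -> List[str]:
    """Iterative DFS returning Placeholder names in first-visit order."""
    # Accept either a tensor (use its producing op) or an op directly
    root = getattr(tensor_or_op, "op", tensor_or_op)
    result: List[str] = []
    visited: Set[FakeOp] = set()
    stack = [root]
    while stack:
        op = stack.pop()
        # Each op is expanded at most once, even if reachable via many paths
        if op in visited:
            continue
        visited.add(op)
        if not op.inputs:
            # Leaf op: keep it only if it is a placeholder
            if op.type == "Placeholder":
                result.append(op.name)
            continue
        # Push inputs in reverse so they are popped (visited) left to right
        for inp in reversed(op.inputs):
            stack.append(inp.op)
    return result


if __name__ == "__main__":
    x = FakeOp("x", "Placeholder")
    y = FakeOp("y", "Placeholder")
    c = FakeOp("c", "Const")
    add = FakeOp("add", "AddV2", [FakeTensor(x), FakeTensor(y)])
    # `x` is reachable both directly and through `add`
    mul = FakeOp("mul", "Mul", [FakeTensor(add), FakeTensor(x), FakeTensor(c)])
    print(get_placeholder_names(FakeTensor(mul)))  # prints ['x', 'y']
```

The explicit stack trades the recursion's brevity for a bounded call depth, and the `visited` set makes the walk linear in the number of ops and edges instead of in the number of distinct paths.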