Ask Your Question

Revision history [back]

click to hide/show revision 1
initial version

To retrieve the shape of all tensors in a graph using c++ in Tensorflow, you can use the following steps:

  1. Load the Tensorflow graph in your c++ application.
  2. Create a Tensorflow session using the graph.
  3. Get the Tensorflow graph's list of operations using the Graph::get_operations() method.
  4. Iterate over each operation, and for each operation, iterate over its input tensors and output tensors.
  5. For each input/output tensor, use the Tensor::shape() method to retrieve its shape.

Here's an example code snippet that demonstrates these steps:

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/graph/graph.h"
#include <iostream>

using namespace tensorflow;

int main() {
  // Load Tensorflow graph
  GraphDef graph_def;
  ReadBinaryProto(Env::Default(), "path/to/your/graph.pb", &graph_def);

  // Create session and load graph
  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));
  Status status = session->Create(graph_def);
  if (!status.ok()) {
    std::cerr << "Error creating Tensorflow session: " << status.ToString() << std::endl;
    return 1;
  }

  // Get list of all operations in the graph
  std::vector<tensorflow::TensorShape> shapes;
  const Graph& graph = session->graph();
  for (const auto& op : graph.op_nodes()) {
      // Iterate over input tensors
      for (const auto& input : op->inputs()) {
          auto tensor = input.template tensor<T>();
          TensorShape shape = tensor.shape();
          shapes.push_back(shape);
          std::cout << "Input tensor shape: " << shape.DebugString() << std::endl;
      }

      // Iterate over output tensors
      for (const auto& output : op->outputs()) {
          auto tensor = output.template tensor<T>();
          TensorShape shape = tensor.shape();
          shapes.push_back(shape);
          std::cout << "Output tensor shape: " << shape.DebugString() << std::endl;
      }
  }

  return 0;
}