
[Improvement] Ensure the graph is connected before connecting geodesics #28

Open
@a-pouplin

Description

@a-pouplin

Is your feature related to a problem?
When plotting geodesics by calling connecting_geodesic, the function assumes that the graph is connected. When it is not, networkx.shortest_path eventually raises an error that users may find cryptic.

Describe the solution you would like:
Either when discretising the manifold or before plotting the geodesics, a check with an explicit error message could be added to verify that the graph is connected, e.g. assert nx.is_connected(graph), "Graph not connected". Additional guidance or best practices could also be provided to help users ensure the graph is connected.
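A minimal sketch of such a check, assuming the discretised manifold stores an undirected networkx.Graph (nx.is_connected is only defined for undirected graphs). The helper name ensure_connected is hypothetical and not part of stochman. Instead of a bare assertion, it raises a ValueError that reports how many connected components were found and how large they are, which makes the failure easier to diagnose.

```python
import networkx as nx


def ensure_connected(graph: nx.Graph) -> None:
    """Raise an informative error if an undirected graph is not connected.

    Hypothetical helper for illustration; stochman does not provide it.
    """
    # nx.is_connected raises NetworkXPointlessConcept on an empty graph,
    # so report that case explicitly first.
    if graph.number_of_nodes() == 0:
        raise ValueError("Graph has no nodes.")
    if not nx.is_connected(graph):
        # Sizes of the connected components, largest first
        sizes = sorted((len(c) for c in nx.connected_components(graph)), reverse=True)
        raise ValueError(
            f"Graph not connected: found {len(sizes)} components with sizes {sizes}. "
            "Geodesics between nodes in different components cannot be computed."
        )
```

Calling ensure_connected(graph) once before the first nx.shortest_path call would surface the problem immediately. To test a single pair of points instead of the whole graph, nx.has_path(graph, source, target) can be used.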

Activity

Changed the title from "[Improvement] Unsure the graph is connected before connecting geodesics" to "[Improvement] Ensure the graph is connected before connecting geodesics" on Mar 20, 2023.

a-pouplin commented on Mar 20, 2023

@a-pouplin
Author

Another improvement would be to check that the sizes are as expected, for example in the DiscretizedManifold.fit() function:

```python
with torch.no_grad():
    weight = model.curve_length(line(t))
    # Check that curve_length returns the expected shape
    assert weight.shape == bs, f"model.curve_length should return a {bs} shape object but found {weight.shape}."
```
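A variant of this check as a standalone helper, sketched under the assumption that curve_length should return one length per curve in the batch, i.e. shape (batch_size,). The name check_curve_length_shape is hypothetical. It raises a ValueError rather than using assert, because assert statements are skipped when Python runs with the -O flag. It only relies on the .shape attribute, so it accepts a torch.Tensor as well as a NumPy array.

```python
from typing import Any


def check_curve_length_shape(weight: Any, batch_size: int) -> None:
    """Verify that there is exactly one curve length per curve in the batch.

    Hypothetical helper; assumes the expected shape is (batch_size,).
    """
    expected = (batch_size,)
    # tuple() turns a torch.Size into a plain tuple for comparison and printing
    found = tuple(weight.shape)
    if found != expected:
        raise ValueError(
            f"model.curve_length should return a tensor of shape {expected} "
            f"but found {found}."
        )
```

A common mismatch this would catch is a model returning shape (batch_size, 1) instead of (batch_size,).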

a-pouplin commented on Mar 20, 2023

@a-pouplin
Author

Another question regarding DiscretizedManifold.fit(): the graph weights are computed from curves evaluated at only two points:

```python
# Each curve is evaluated at only two times, t = 0 and t = 1
t = torch.linspace(0, 1, 2)

(...)

with torch.no_grad():
    weight = model.curve_length(line(t))
```

This method mainly relies on a metric tensor to compute the graph weights (curve_length depends on inner_product, which depends on metric). However, when the metric tensor is not easily accessible, one might want to compute the curve length from the derivatives $\dot{\gamma}_t$ of the curve: $L[\gamma] = \int_0^1 \|\dot{\gamma}_t\| \, dt$. These derivatives can only be estimated accurately if the curve is discretised finely enough.
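To illustrate why the discretisation matters, here is a sketch that approximates $L[\gamma]$ with forward finite differences over num_curve_points evaluation times. It uses NumPy rather than torch to stay self-contained. The speed argument, a function of position and velocity standing in for $\|\dot{\gamma}_t\|$ (Euclidean by default, but it could be, e.g., a Finsler norm), and the function names are assumptions of this sketch, not a stochman API.

```python
import numpy as np


def euclidean_speed(x, v):
    """Euclidean norm of each velocity row; the positions x are ignored."""
    return np.linalg.norm(v, axis=-1)


def curve_length_fd(curve, num_curve_points=2, speed=euclidean_speed):
    """Approximate L[gamma] = int_0^1 speed(gamma_t, gamma_dot_t) dt.

    curve: callable mapping times of shape (T,) to points of shape (T, D).
    speed: callable mapping positions and velocities of shape (N, D) to (N,).
    """
    if num_curve_points < 2:
        raise ValueError("num_curve_points must be at least 2")
    t = np.linspace(0.0, 1.0, num_curve_points)
    points = curve(t)                                   # (T, D)
    dt = np.diff(t)                                     # (T - 1,)
    # Forward-difference estimate of the velocity on each segment
    velocities = np.diff(points, axis=0) / dt[:, None]  # (T - 1, D)
    # Evaluate the speed at segment midpoints (midpoint rule)
    midpoints = 0.5 * (points[:-1] + points[1:])        # (T - 1, D)
    return float(np.sum(speed(midpoints, velocities) * dt))


def quarter_circle(t):
    """Unit quarter circle from (1, 0) to (0, 1); its true length is pi / 2."""
    angle = 0.5 * np.pi * t
    return np.stack([np.cos(angle), np.sin(angle)], axis=1)
```

On the quarter circle, num_curve_points=2 reduces the estimate to the chord length $\sqrt{2} \approx 1.414$, while the true length is $\pi/2 \approx 1.571$; a few hundred points bring the estimate close to the true value.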

The question is: would it make sense to add an argument (e.g. num_curve_points) to discretise the curve and use its derivatives to compute the expected metric (or, for example, a Finsler metric)?

```python
def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise=0.0, num_curve_points=2):

    (...)

    # The default num_curve_points=2 reproduces the current two-point evaluation
    t = torch.linspace(0, 1, num_curve_points)
```
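A hypothetical batched version of the proposed change might look like the following. For a batch of straight-line edges between grid points, each line is evaluated at num_curve_points times and a user-supplied speed function is integrated, giving one weight per edge, i.e. shape (batch_size,). This is a NumPy sketch of the idea, not stochman's implementation. The names straight_edge_lengths and speed are illustrative.

```python
import numpy as np


def euclidean_speed(x, v):
    """Euclidean norm of the velocity along the last axis; x is ignored."""
    return np.linalg.norm(v, axis=-1)


def straight_edge_lengths(starts, ends, num_curve_points=2, speed=euclidean_speed):
    """Lengths of the straight lines starts[b] -> ends[b] under a speed function.

    starts, ends: arrays of shape (B, D), one row per graph edge.
    speed: callable mapping positions and velocities of shape (..., D) to (...).
    Returns an array of shape (B,): one weight per edge.
    """
    if num_curve_points < 2:
        raise ValueError("num_curve_points must be at least 2")
    starts = np.asarray(starts, dtype=float)
    ends = np.asarray(ends, dtype=float)
    t = np.linspace(0.0, 1.0, num_curve_points)                 # (T,)
    # Points along each line: x(t) = (1 - t) * start + t * end, shape (B, T, D)
    points = (1.0 - t)[None, :, None] * starts[:, None, :] + t[None, :, None] * ends[:, None, :]
    dt = np.diff(t)                                              # (T - 1,)
    velocities = np.diff(points, axis=1) / dt[None, :, None]     # (B, T - 1, D)
    midpoints = 0.5 * (points[:, :-1] + points[:, 1:])           # (B, T - 1, D)
    # Midpoint-rule integral of the speed over each line
    return np.sum(speed(midpoints, velocities) * dt[None, :], axis=1)  # (B,)
```

With the Euclidean default, the result is the straight-line distance for any num_curve_points. A larger value only changes the weights when the speed depends on position, as it does with a metric or a Finsler norm.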

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

      Development

      No branches or pull requests

        Participants

        @a-pouplin


          [Improvement] Ensure the graph is connected before connecting geodesics · Issue #28 · MachineLearningLifeScience/stochman