3 votes

Comment vérifier si une matrice est inversible dans tensorflow ?

Dans mon graphe Tensorflow, je voudrais inverser une matrice ; si elle est inversable, faire quelque chose avec. Si elle n'est pas inversible, le, je voudrais faire autre chose.

Je n'ai pas trouvé de moyen de vérifier si la matrice est inversible afin de faire quelque chose comme :

is_invertible = tf.is_invertible(mat)
tf.cond(is_invertible, f1, f2)

Existe-t-il une chose telle qu'un is_invertible dans Tensorflow ? J'ai aussi envisagé d'utiliser le fait que Tensorflow lève (pas à chaque fois cependant) une fonction InvalidArgumentError lorsque j'essaie d'inverser une matrice non inversable, mais je n'ai pas pu en tirer parti.

3voto

jdehesa Points 22254

Comme proposé dans Vérification efficace et pythique des matrices singulières vous pouvez vérifier le numéro de la condition. Malheureusement, ceci n'est pas actuellement implémenté dans TensorFlow en tant que tel, mais il n'est pas difficile d'émuler l'implémentation de base de np.linalg.cond :

import math
import tensorflow as tf

# Based on np.linalg.cond(x, p=None)
def tf_cond(x):
    x = tf.convert_to_tensor(x)
    s = tf.linalg.svd(x, compute_uv=False)
    r = s[..., 0] / s[..., -1]
    # Replace NaNs in r with infinite unless there were NaNs before
    x_nan = tf.reduce_any(tf.is_nan(x), axis=(-2, -1))
    r_nan = tf.is_nan(r)
    r_inf = tf.fill(tf.shape(r), tf.constant(math.inf, r.dtype))
    tf.where(x_nan, r, tf.where(r_nan, r_inf, r))
    return r

def is_invertible(x, epsilon=1e-6):  # Epsilon may be smaller with tf.float64
    x = tf.convert_to_tensor(x)
    eps_inv = tf.cast(1 / epsilon, x.dtype)
    x_cond = tf_cond(x)
    return tf.is_finite(x_cond) & (x_cond < eps_inv)

m = [
    # Invertible matrix
    [[1., 2., 3.],
     [6., 5., 4.],
     [7., 7., 8.]],
    # Non-invertible matrix
    [[1., 2., 3.],
     [6., 5., 4.],
     [7., 7., 7.]],
]
with tf.Graph().as_default(), tf.Session() as sess:
    print(sess.run(is_invertible(m)))
    # [ True False]

0voto

zwu381 Points 1

Qu'en est-il

if m.shape[0] == m.shape[1]:
    # when m is a square matrix
    try:
        tf.linalg.inv(L)
        print(tf.linalg.inv(L))
        print("invertible")
    except:
        print("not invertible")
else:
    # m is not invertible if it is not even a squared matrix
    print("not invertible")

Puisque lorsque vous essayez d'inverser une matrice non inversible avec tf.linalg.inv , un InvalidArgumentError est lancée uniquement lorsque l'entrée, en supposant qu'elle soit 2d, n'est pas une matrice carrée. Et quand ce n'est pas une matrice carrée, elle est carrément non-invertible.

0 votes

La question est assez ancienne, mais si je me souviens bien, j'ai eu cette erreur InvalidArgumentError avec des matrices carrées également.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X