3 votes

Je reçois une erreur de l'objet 'ModifiedTensorBoard' qui n'a pas d'attribut '_write_logs' en exécutant ce code pour former un Deep Q Network.

J'exécute le code suivant :

class ModifiedTensorBoard(TensorBoard):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.step = 1
        self.writer = tf.summary.create_file_writer(self.log_dir)

    def set_model(self, model):
        pass

    def on_epoch_end(self, epoch, logs=None):
        self.update_stats(**logs)

    def on_batch_end(self, batch, logs=None):
        pass

    def on_train_end(self, _):
        pass

    def update_stats(self, **stats):
        self._write_logs(stats, self.step)

L'erreur est obtenue :

    232     def update_stats(self, **stats):
--> 233         self._write_logs(stats, self.step)
    234 
    235 

AttributeError: 'ModifiedTensorBoard' object has no attribute '_write_logs'

Quelqu'un peut-il suggérer comment surmonter ce problème ? Merci.

0voto

Mingming Qiu Points 11

Créer une fonction "_write_logs" comme suit :

def update_stats(self, **stats):
    self._write_logs(stats, self.step)

def _write_logs(self, logs, index):
    with self.writer.as_default():
        for name, value in logs.items():
            tf.summary.scalar(name, value, step=index)
            self.step += 1
            self.writer.flush()

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X