20
Июл
2021

tf.Variable созданный в __call__ слоя. tensorflow, python

У меня есть мой слой:

class Laplacian(tf.keras.layers.Layer):
  def __init__(self):
      super(Laplacian, self).__init__()

  def build(self, input_shape):
      self.filter = tf.constant([[0, -1, 0],
                                [-1, 4, -1],
                                [0, -1, 0]], dtype=tf.float32)

  def call(self, tensor):
      stride = (1, 1)
      channels = tensor.shape[3]
      n1 = tf.cast(((tf.shape(tensor)[1]-self.filter.shape[0])/stride[0]+1), dtype=tf.int32)
      n2 = tf.cast(((tf.shape(tensor)[2]-self.filter.shape[1])/stride[1]+1), dtype=tf.int32)
      result = tf.Variable(tf.zeros((tf.shape(tensor)[0], n1, n2, channels)), dtype=tf.float32, trainable=False)

      for ch in range(channels):
          for row in range(0, tf.shape(tensor)[1], stride[0]):
              for col in range(0, tf.shape(tensor)[2], stride[1]):
                  if (row+self.filter.shape[0]-1<tf.shape(tensor)[1] and col+self.filter.shape[1]-1<tf.shape(tensor)[2]):
                      temp_mtx = tensor[:, row:self.filter.shape[0]+row, col:self.filter.shape[1]+col, ch]
                      val = tf.tensordot(temp_mtx, self.filter, 2)
                      #print(val)
                      result[:, row//stride[0], col//stride[1], ch].assign(val)
      return result

когда я хочу добавить его в свою модель, он ругается на tf.Variable:

avg_pool = tf.keras.layers.AveragePooling2D(pool_size=(3, 3), strides=(3, 3), padding='valid', name='avg_pooling2D')(vgg.layers[-2].output) 
lap_layer = Laplacian()(avg_pool)

Вот такая ошибка:

ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., tf.Variable(lambda : tf.truncated_normal([10, 40]))) when building functions. Please file a feature request if this restriction inconveniences you.

Если я оборачиваю variable в lambda, то непонятно, как мне к ней обращаться и изменять элементы моего тензора result.

Источник: https://ru.stackoverflow.com/questions/1306868/tf-variable-%D1%81%D0%BE%D0%B7%D0%B4%D0%B0%D0%BD%D0%BD%D1%8B%D0%B9-%D0%B2-call-%D1%81%D0%BB%D0%BE%D1%8F-tensorflow-python

Тебе может это понравится...

Добавить комментарий