Le tenseur de sortie est défini comme suit :
// Grab the input tensor
const Tensor& I_tensor = context->input(0);
const Tensor& W_tensor = context->input(1);
auto Input = I_tensor.flat<float>();
auto Weight = W_tensor.flat<float>();
// OP_REQUIRES(context, iA_tensor.dims()==2 && iB_tensor.dims()==2);
int B = I_tensor.dim_size(1);
int nH= I_tensor.dim_size(2);
int nW= I_tensor.dim_size(3);
int C = I_tensor.dim_size(4);
int K = W_tensor.dim_size(2);
// Create an output tensor
Tensor* O_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{B, 2*nH, 2*nW, K}, &O_tensor));
J'ai essayé d'utiliser le même code à l'intérieur du SetShapeFn, mais cela ne fonctionne pas.
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
// // const Tensor& I_tensor = c->input(0); //This doesn't work
// // const Tensor& W_tensor = c->input(1); //This doesn't work
::tensorflow::shape_inference::ShapeHandle I = c->input(0);
::tensorflow::shape_inference::ShapeHandle W = c->input(1);
//I have no idea how to get the desired shape from I and W
return Status::OK();
})
J'ai essayé de trouver plus d'informations sur ShapeHandle et InferenceContext, afin d'obtenir le résultat souhaité, mais je n'y suis pas parvenu. Je serai reconnaissant si quelqu'un peut m'aider à ce sujet.
Merci !