Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions tmva/tmva/inc/TMVA/DNN/Architectures/Cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,35 @@ class TCpu

///@}

//____________________________________________________________________________
//
// Average Pooling Layer Propagation
//____________________________________________________________________________
/** @name Forward Propagation in Avg Pooling Layer
*/
///@{

/** Downsample the matrix \p B to the matrix \p A, using avg
* operation
*/
static void DownsampleAvg(TCpuMatrix<AReal> &A, const TCpuMatrix<AReal> &B, size_t imgHeight,
size_t imgWidth, size_t fltHeight, size_t fltWidth, size_t strideRows, size_t strideCols);

///@}

/** @name Backward Propagation in Avg Pooling Layer
*/
///@{
/** Perform the complete backward propagation step in a Pooling Layer. Based on the
* filter sizes used for computing the average, it just forwards the activation
* gradients to the previous layer. */
static void AvgPoolLayerBackward(std::vector<TCpuMatrix<AReal>> &activationGradientsBackward,
const std::vector<TCpuMatrix<AReal>> &activationGradients,
size_t batchSize, size_t depth, size_t nLocalViews,
size_t fltHeight, size_t fltWidth);

///@}

//____________________________________________________________________________
//
// Reshape Layer Propagation
Expand Down
30 changes: 30 additions & 0 deletions tmva/tmva/inc/TMVA/DNN/Architectures/Cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,36 @@ class TCuda

///@}

//____________________________________________________________________________
//
// Average Pooling Layer Propagation
//____________________________________________________________________________
/** @name Forward Propagation in Avg Pooling Layer
*/
///@{

/** Downsample the matrix \p B to the matrix \p A, using avg
* operation
*/
static void DownsampleAvg(TCudaMatrix<AFloat> &A, const TCudaMatrix<AFloat> &B,
const int imgHeight, const int imgWidth, const int fltHeight, const int fltWidth,
const int strideRows, const int strideCols);
///@}

/** @name Backward Propagation in Avg Pooling Layer
*/
///@{

/** Perform the complete backward propagation step in a Pooling Layer. Based on the
* filter sizes used for computing the average, it just forwards the activation
* gradients to the previous layer. */
static void AvgPoolLayerBackward(std::vector<TCudaMatrix<AFloat>> &activationGradientsBackward,
const std::vector<TCudaMatrix<AFloat>> &activationGradients,
size_t batchSize, size_t depth,size_t nLocalViews,
size_t fltHeight, size_t fltWidth);

///@}

//____________________________________________________________________________
//
// Reshape Layer Propagation
Expand Down
31 changes: 31 additions & 0 deletions tmva/tmva/inc/TMVA/DNN/Architectures/Reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,37 @@ class TReference
size_t nLocalViews);
///@}

//____________________________________________________________________________
//
// Average Pooling Layer Propagation
//____________________________________________________________________________
/** @name Forward Propagation in Avg Pooling Layer
*/
///@{

/** Downsample the matrix \p B to the matrix \p A, using avg
* operation
*/
static void DownsampleAvg(TMatrixT<AReal> &A, const TMatrixT<AReal> &B, size_t imgHeight,
size_t imgWidth, size_t fltHeight, size_t fltWidth, size_t strideRows, size_t strideCols);

///@}

/** @name Backward Propagation in Avg Pooling Layer
*/
///@{

/** Perform the complete backward propagation step in a Avg Pooling Layer. Based on the
* filter sizes used for computing the average, it just forwards the actiovation
* gradients to the previous layer. */
static void AvgPoolLayerBackward(std::vector<TMatrixT<AReal>> &activationGradientsBackward,
const std::vector<TMatrixT<AReal>> &activationGradients,
size_t batchSize, size_t depth, size_t nLocalViews,
size_t fltHeight, size_t fltWidth);
///@}



//____________________________________________________________________________
//
// Reshape Layer Propagation
Expand Down
201 changes: 201 additions & 0 deletions tmva/tmva/inc/TMVA/DNN/CNN/AvgPoolLayer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Vladimir Ilievski

/**********************************************************************************
* Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
* Package: TMVA *
* Class : TAvgPoolLayer *
* Web : http://tmva.sourceforge.net *
* *
* Description: *
* Max Pool Deep Neural Network Layer *
* *
* Authors (alphabetical): *
* Vladimir Ilievski <ilievski.vladimir@live.com> - CERN, Switzerland *
* *
* Copyright (c) 2005-2015: *
* CERN, Switzerland *
* U. of Victoria, Canada *
* MPI-K Heidelberg, Germany *
* U. of Bonn, Germany *
* *
* Redistribution and use in source and binary forms, with or without *
* modification, are permitted according to the terms listed in LICENSE *
* (http://tmva.sourceforge.net/LICENSE) *
**********************************************************************************/

#ifndef AVGPOOLLAYER_H_
#define AVGPOOLLAYER_H_

#include "TMatrix.h"

#include "TMVA/DNN/GeneralLayer.h"
#include "TMVA/DNN/Functions.h"

#include <iostream>

namespace TMVA {
namespace DNN {
namespace CNN {

/** \class TAvgPoolLayer

Generic Max Pooling Layer class.

This generic Max Pooling Layer Class represents a pooling layer of
a CNN. It inherits all of the properties of the generic virtual base class
VGeneralLayer. In addition to that, it contains a matrix of winning units.

The height and width of the weights and biases is set to 0, since this
layer does not contain any weights.

*/
template <typename Architecture_t>
class TAvgPoolLayer : public VGeneralLayer<Architecture_t> {
public:
using Matrix_t = typename Architecture_t::Matrix_t;
using Scalar_t = typename Architecture_t::Scalar_t;

private:

size_t fFrameHeight; ///< The height of the frame.
size_t fFrameWidth; ///< The width of the frame.

size_t fStrideRows; ///< The number of row pixels to slid the filter each step.
size_t fStrideCols; ///< The number of column pixels to slid the filter each step.

size_t fNLocalViewPixels; ///< The number of pixels in one local image view.
size_t fNLocalViews; ///< The number of local views in one image.

Scalar_t fDropoutProbability; ///< Probability that an input is active.

public:
/*! Constructor. */
TAvgPoolLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Height,
size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols, size_t FrameHeight,
size_t FrameWidth, size_t StrideRows, size_t StrideCols, Scalar_t DropoutProbability);

/*! Copy the max pooling layer provided as a pointer */
TAvgPoolLayer(TAvgPoolLayer<Architecture_t> *layer);

/*! Copy constructor. */
TAvgPoolLayer(const TAvgPoolLayer &);

/*! Destructor. */
~TAvgPoolLayer();

/*! Computes activation of the layer for the given input. The input
* must be in 3D tensor form with the different matrices corresponding to
* different events in the batch. It spatially downsamples the input
* matrices. */
void Forward(std::vector<Matrix_t> &input, bool applyDropout = false);

/*! Depending on the winning units determined during the Forward pass,
* it only forwards the derivatives to the right units in the previous
* layer. Must only be called directly at the corresponding call
* to Forward(...). */
void Backward(std::vector<Matrix_t> &gradients_backward, const std::vector<Matrix_t> &activations_backward,
std::vector<Matrix_t> &inp1, std::vector<Matrix_t> &inp2);

/*! Prints the info about the layer. */
void Print() const;

/*! Getters */
size_t GetFrameHeight() const { return fFrameHeight; }
size_t GetFrameWidth() const { return fFrameWidth; }

size_t GetStrideRows() const { return fStrideRows; }
size_t GetStrideCols() const { return fStrideCols; }

size_t GetNLocalViewPixels() const { return fNLocalViewPixels; }
size_t GetNLocalViews() const { return fNLocalViews; }

Scalar_t GetDropoutProbability() const { return fDropoutProbability; }
};

//______________________________________________________________________________
template <typename Architecture_t>
TAvgPoolLayer<Architecture_t>::TAvgPoolLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
size_t height, size_t width, size_t outputNSlices, size_t outputNRows,
size_t outputNCols, size_t frameHeight, size_t frameWidth,
size_t strideRows, size_t strideCols, Scalar_t dropoutProbability)
: VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, inputDepth, height, width, 0, 0, 0,
0, 0, 0, outputNSlices, outputNRows, outputNCols, EInitialization::kZero),
fFrameHeight(frameHeight), fFrameWidth(frameWidth), fStrideRows(strideRows),
fStrideCols(strideCols), fNLocalViewPixels(inputDepth * frameHeight * frameWidth), fNLocalViews(height * width),
fDropoutProbability(dropoutProbability)
{
}

//______________________________________________________________________________
template <typename Architecture_t>
TAvgPoolLayer<Architecture_t>::TAvgPoolLayer(TAvgPoolLayer<Architecture_t> *layer)
: VGeneralLayer<Architecture_t>(layer), fFrameHeight(layer->GetFrameHeight()),
fFrameWidth(layer->GetFrameWidth()), fStrideRows(layer->GetStrideRows()), fStrideCols(layer->GetStrideCols()),
fNLocalViewPixels(layer->GetNLocalViewPixels()), fNLocalViews(layer->GetNLocalViews()),
fDropoutProbability(layer->GetDropoutProbability())
{
}

//______________________________________________________________________________
template <typename Architecture_t>
TAvgPoolLayer<Architecture_t>::TAvgPoolLayer(const TAvgPoolLayer &layer)
: VGeneralLayer<Architecture_t>(layer), fFrameHeight(layer.fFrameHeight),
fFrameWidth(layer.fFrameWidth), fStrideRows(layer.fStrideRows), fStrideCols(layer.fStrideCols),
fNLocalViewPixels(layer.fNLocalViewPixels), fNLocalViews(layer.fNLocalViews),
fDropoutProbability(layer.fDropoutProbability)
{
}

//______________________________________________________________________________
template <typename Architecture_t>
TAvgPoolLayer<Architecture_t>::~TAvgPoolLayer()
{
}

//______________________________________________________________________________
template <typename Architecture_t>
auto TAvgPoolLayer<Architecture_t>::Forward(std::vector<Matrix_t> &input, bool applyDropout) -> void
{
for (size_t i = 0; i < this->GetBatchSize(); i++) {

if (applyDropout && (this->GetDropoutProbability() != 1.0)) {
Architecture_t::Dropout(input[i], this->GetDropoutProbability());
}

Architecture_t::DownsampleAvg(this->GetOutputAt(i), input[i], this->GetInputHeight(),
this->GetInputWidth(), this->GetFrameHeight(), this->GetFrameWidth(),
this->GetStrideRows(), this->GetStrideCols());
}
}

//______________________________________________________________________________
template <typename Architecture_t>
auto TAvgPoolLayer<Architecture_t>::Backward(std::vector<Matrix_t> &gradients_backward,
const std::vector<Matrix_t> & /*activations_backward*/,
std::vector<Matrix_t> & /*inp1*/, std::vector<Matrix_t> &
/*inp2*/) -> void
{
Architecture_t::AvgPoolLayerBackward(gradients_backward, this->GetActivationGradients(),
this->GetBatchSize(), this->GetDepth(), this->GetNLocalViews(),
this->GetFrameHeight(), this->GetFrameWidth());
}

//______________________________________________________________________________
template <typename Architecture_t>
auto TAvgPoolLayer<Architecture_t>::Print() const -> void
{
std::cout << "\t\t POOL LAYER: " << std::endl;
std::cout << "\t\t\t Width = " << this->GetWidth() << std::endl;
std::cout << "\t\t\t Height = " << this->GetHeight() << std::endl;
std::cout << "\t\t\t Depth = " << this->GetDepth() << std::endl;

std::cout << "\t\t\t Frame Width = " << this->GetFrameWidth() << std::endl;
std::cout << "\t\t\t Frame Height = " << this->GetFrameHeight() << std::endl;
}

} // namespace CNN
} // namespace DNN
} // namespace TMVA

#endif
54 changes: 54 additions & 0 deletions tmva/tmva/src/DNN/Architectures/Cpu/Propagation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,60 @@ void TCpu<AFloat>::MaxPoolLayerBackward(std::vector<TCpuMatrix<AFloat>> &activat
}
}

//____________________________________________________________________________
template <typename AFloat>
void TCpu<AFloat>::DownsampleAvg(TCpuMatrix<AFloat> &A, const TCpuMatrix<AFloat> &B,
size_t imgHeight, size_t imgWidth, size_t fltHeight, size_t fltWidth, size_t strideRows,
size_t strideCols)
{
// image boudaries
int imgHeightBound = imgHeight - (fltHeight - 1) / 2 - 1;
int imgWidthBound = imgWidth - (fltWidth - 1) / 2 - 1;
size_t currLocalView = 0;

// centers
for (int i = fltHeight / 2; i <= imgHeightBound; i += strideRows) {
for (int j = fltWidth / 2; j <= imgWidthBound; j += strideCols) {
// within local views
for (int m = 0; m < (Int_t)B.GetNrows(); m++) {
AFloat value = 0;

for (int k = i - fltHeight / 2; k <= Int_t(i + (fltHeight - 1) / 2); k++) {
for (int l = j - fltWidth / 2; l <= Int_t(j + (fltWidth - 1) / 2); l++) {
value += B(m, k * imgWidth + l);
}
}
A(m, currLocalView) = value/(fltHeight*fltWidth);
}
currLocalView++;
}
}
}

//____________________________________________________________________________
template <typename AFloat>
void TCpu<AFloat>::AvgPoolLayerBackward(std::vector<TCpuMatrix<AFloat>> &activationGradientsBackward,
const std::vector<TCpuMatrix<AFloat>> &activationGradients,
size_t batchSize, size_t depth, size_t nLocalViews,
size_t fltHeight, size_t fltWidth)
{
for (size_t i = 0; i < batchSize; i++) {
for (size_t j = 0; j < depth; j++) {

// initialize to zeros
for (size_t t = 0; t < (size_t)activationGradientsBackward[i].GetNcols(); t++) {
activationGradientsBackward[i](j, t) = 0;
}

// set values
for (size_t k = 0; k < nLocalViews; k++) {
AFloat grad = activationGradients[i](j, k);
activationGradientsBackward[i](j, k) += grad/(fltHeight*fltWidth);
}
}
}
}

//____________________________________________________________________________
template <typename AFloat>
void TCpu<AFloat>::Reshape(TCpuMatrix<AFloat> &A, const TCpuMatrix<AFloat> &B)
Expand Down
Loading