Skip to content

Commit c7ba50c

Browse files
Duyi-Wangchangqi1marvin-Yu
authored
[Op] Add FP32 fused l2 normalize op and grad op. (#291)
Co-authored-by: Li, Changqing <changqing.li@intel.com> Co-authored-by: marvinYu <weifei.yu@intel.com>
1 parent 3d1b640 commit c7ba50c

File tree

13 files changed

+934
-6
lines changed

13 files changed

+934
-6
lines changed

tensorflow/core/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,7 @@ tf_gen_op_libs(
11811181
"function_ops",
11821182
"functional_ops",
11831183
"fused_embedding_ops",
1184+
"fused_l2_normalize_ops",
11841185
"hash_ops",
11851186
"hash_training_ops",
11861187
"fuserecv_ops",
@@ -1439,6 +1440,7 @@ cc_library(
14391440
":function_ops_op_lib",
14401441
":functional_ops_op_lib",
14411442
":fused_embedding_ops_op_lib",
1443+
":fused_l2_normalize_ops_op_lib",
14421444
":fuserecv_ops_op_lib",
14431445
":hash_ops_op_lib",
14441446
":hash_training_ops_op_lib",
@@ -1623,6 +1625,7 @@ cc_library(
16231625
"//tensorflow/core/kernels:functional_ops",
16241626
"//tensorflow/core/kernels:fused_embedding_ops",
16251627
"//tensorflow/core/kernels/data:parquet_dataset_ops",
1628+
"//tensorflow/core/kernels:fused_l2_normalize_ops",
16261629
"//tensorflow/core/kernels:grappler",
16271630
"//tensorflow/core/kernels:hash_ops",
16281631
"//tensorflow/core/kernels:histogram_op",
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
op {
2+
graph_op_name: "FusedL2Normalize"
3+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
op {
2+
graph_op_name: "FusedL2NormalizeGrad"
3+
}

tensorflow/core/kernels/BUILD

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5403,6 +5403,37 @@ tf_cc_test(
54035403
],
54045404
)
54055405

5406+
tf_kernel_library(
5407+
name = "fused_l2_normalize_ops",
5408+
srcs = [
5409+
"fused_l2_normalize/fused_l2_normalize_op.cc",
5410+
],
5411+
hdrs = ["fused_l2_normalize/compile_util.h"],
5412+
deps = ["//third_party/eigen3"] + DYNAMIC_DEPS + mkl_deps(),
5413+
)
5414+
5415+
tf_cc_test(
5416+
name = "fused_l2_normalize_ops_test",
5417+
size = "small",
5418+
srcs = ["fused_l2_normalize/fused_l2_normalize_op_test.cc",
5419+
"fused_l2_normalize/fused_l2_normalize_grad_op_test.cc"],
5420+
deps = [
5421+
":fused_l2_normalize_ops",
5422+
":ops_testutil",
5423+
":ops_util",
5424+
"//tensorflow/cc:cc_ops",
5425+
"//tensorflow/core:core_cpu",
5426+
"//tensorflow/core:framework",
5427+
"//tensorflow/core:framework_internal",
5428+
"//tensorflow/core:lib",
5429+
"//tensorflow/core:protos_all_cc",
5430+
"//tensorflow/core:tensorflow",
5431+
"//tensorflow/core:test",
5432+
"//tensorflow/core:test_main",
5433+
"//tensorflow/core:testlib",
5434+
],
5435+
)
5436+
54065437
tf_kernel_library(
54075438
name = "run_graph_op",
54085439
prefix = "run_graph_op",
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#ifndef TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_
2+
#define TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_
3+
4+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
5+
#include "tensorflow/core/framework/tensor_types.h"
6+
7+
namespace tensorflow {
8+
namespace functor {
9+
10+
#include <type_traits>
11+
12+
// A class for forced loop unrolling at compile time
13+
template <int i>
14+
struct compile_time_for {
15+
template <typename Lambda, typename... Args>
16+
inline static void op(const Lambda& function, Args... args) {
17+
compile_time_for<i-1>::op(function, args...);
18+
function(std::integral_constant<int, i-1>{}, args...);
19+
}
20+
};
21+
template <>
22+
struct compile_time_for<1> {
23+
template <typename Lambda, typename... Args>
24+
inline static void op(const Lambda& function, Args... args) {
25+
function(std::integral_constant<int, 0>{}, args...);
26+
}
27+
};
28+
template <>
29+
struct compile_time_for<0> {
30+
// 0 loops, do nothing
31+
template <typename Lambda, typename... Args>
32+
inline static void op(const Lambda& function, Args... args) {
33+
}
34+
};
35+
36+
} // namespace functor
37+
} // namespace tensorflow
38+
39+
#endif // TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_
40+
41+
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
2+
#include "tensorflow/core/framework/fake_input.h"
3+
#include "tensorflow/core/framework/node_def_builder.h"
4+
#include "tensorflow/core/framework/tensor.h"
5+
#include "tensorflow/core/framework/types.h"
6+
#include "tensorflow/core/kernels/conv_ops_gpu.h"
7+
#include "tensorflow/core/kernels/ops_testutil.h"
8+
#include "tensorflow/core/kernels/ops_util.h"
9+
#include "tensorflow/core/platform/test.h"
10+
#include "tensorflow/core/platform/test_benchmark.h"
11+
#include "tensorflow/core/public/session.h"
12+
#include "tensorflow/cc/ops/standard_ops.h"
13+
14+
namespace tensorflow {
15+
namespace {
16+
17+
enum class Device { CPU, GPU };
18+
19+
class FusedL2NormalizeGradOpTest : public OpsTestBase {
20+
protected:
21+
void MakeOpAndSetDevice(Device device, DataType dtype, int axis, float epsilon) {
22+
TF_EXPECT_OK(NodeDefBuilder("fused_l2_normalize_grad", "FusedL2NormalizeGrad")
23+
.Attr("T", dtype)
24+
.Attr("T", dtype)
25+
.Attr("axis", axis)
26+
.Attr("epsilon", epsilon)
27+
.Input(FakeInput(DT_FLOAT))
28+
.Input(FakeInput(DT_FLOAT))
29+
.Finalize(node_def()));
30+
TF_EXPECT_OK(InitOp());
31+
}
32+
};
33+
34+
TEST_F(FusedL2NormalizeGradOpTest, 2Dims_Float) {
35+
const int rows = 4;
36+
const int cols = 252; //128+64+32+16+8+4=252 1008
37+
38+
MakeOpAndSetDevice(Device::CPU, DT_FLOAT, 0, 1e-12);
39+
40+
// y_grad
41+
float y_grad_array[1008];
42+
for (int i = 0; i < rows * cols; i++) {
43+
y_grad_array[i] = 1.0;
44+
}
45+
y_grad_array[251] = 2.0;
46+
y_grad_array[503] = 2.0;
47+
y_grad_array[755] = 2.0;
48+
y_grad_array[1007] = 2.0;
49+
AddInputFromArray<float>(TensorShape({rows, cols}), y_grad_array);
50+
51+
// x
52+
float x_array[1008];
53+
for (int i = 0; i < rows * cols; i++) {
54+
x_array[i] = 1.0;
55+
}
56+
AddInputFromArray<float>(TensorShape({rows, cols}), x_array);
57+
58+
TF_ASSERT_OK(RunOpKernel());
59+
TF_EXPECT_OK(device_->Sync());
60+
61+
{
62+
Tensor expected_output(allocator(), DT_FLOAT,
63+
TensorShape({rows, cols}));
64+
float output_array[1008];
65+
for (int i = 0; i < rows * cols; i++) {
66+
output_array[i] = - 1.0 / (252 * std::sqrt(252));
67+
}
68+
output_array[251] = 251.0 / (252 * std::sqrt(252));
69+
output_array[503] = 251.0 / (252 * std::sqrt(252));
70+
output_array[755] = 251.0 / (252 * std::sqrt(252));
71+
output_array[1007] = 251.0 / (252 * std::sqrt(252));
72+
test::FillValues<float>(&expected_output, output_array);
73+
test::ExpectTensorNear<float>(expected_output, *GetOutput(0), 1e-6);
74+
}
75+
}
76+
77+
//----------------------------------------------------------------------------//
78+
// Performance benchmarks //
79+
//----------------------------------------------------------------------------//
80+
static Graph* FusedL2NormalizeGrad(int rows, int cols) {
81+
Graph* g = new Graph(OpRegistry::Global());
82+
DataType dtype = DT_FLOAT;
83+
84+
Tensor in1(dtype, TensorShape({rows, cols}));
85+
in1.flat<float>().setRandom();
86+
Tensor in2(dtype, TensorShape({rows, cols}));
87+
in2.flat<float>().setRandom();
88+
89+
Node* input_in1 = test::graph::Constant(g, in1);
90+
Node* input_in2 = test::graph::Constant(g, in2);
91+
auto nodeBuilder = NodeBuilder(g->NewName("n"), "FusedL2NormalizeGrad")
92+
.Input(input_in1)
93+
.Input(input_in2)
94+
.Attr("T", dtype)
95+
.Attr("axis", 0)
96+
.Attr("epsilon", 1e-12);
97+
TF_CHECK_OK(nodeBuilder.Finalize(g, nullptr));
98+
99+
return g;
100+
}
101+
102+
#define BM_FusedL2NormGrad(ROWS, COLS, NTH) \
103+
static void BM_FusedL2NormGrad##_##ROWS##_##COLS##_##NTH##_CPU( \
104+
int iters) { \
105+
testing::UseRealTime(); \
106+
testing::ItemsProcessed(static_cast<int64>(iters) * ROWS * COLS * 5); \
107+
SessionOptions opts; \
108+
opts.config.set_intra_op_parallelism_threads(NTH); \
109+
test::Benchmark("cpu", FusedL2NormalizeGrad(ROWS, COLS), &opts).Run(iters); \
110+
} \
111+
BENCHMARK(BM_FusedL2NormGrad##_##ROWS##_##COLS##_##NTH##_CPU); \
112+
113+
#define BM_FusedL2NormGrad_NTH(ROWS, COLS) \
114+
BM_FusedL2NormGrad(ROWS, COLS, 1); \
115+
BM_FusedL2NormGrad(ROWS, COLS, 4); \
116+
BM_FusedL2NormGrad(ROWS, COLS, 8); \
117+
118+
BM_FusedL2NormGrad_NTH(1024, 63);
119+
BM_FusedL2NormGrad_NTH(1024, 127);
120+
BM_FusedL2NormGrad_NTH(1024, 255);
121+
BM_FusedL2NormGrad_NTH(1024, 511);
122+
BM_FusedL2NormGrad_NTH(1024, 1023);
123+
124+
}
125+
}

0 commit comments

Comments
 (0)