Skip to content
72 changes: 72 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,10 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
std::numeric_limits<int32_t>::min(), &out_len);
EXPECT_EQ(std::string(out_str, out_len), "a.b.c");
EXPECT_FALSE(ctx.has_error());

out_str = gdv_fn_substring_index(ctx_ptr, "a", -2, ".", -1, -50, &out_len);
EXPECT_STREQ(out_str, "");
EXPECT_EQ(out_len, 0);
}

TEST(TestGdvFnStubs, TestUpper) {
Expand Down Expand Up @@ -640,6 +644,26 @@ TEST(TestGdvFnStubs, TestUpper) {
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr(
"unexpected byte \\c3 encountered while decoding utf8 string"));

ctx.Reset();

// Max Len Test
out_len = -1;
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
const char* out = gdv_fn_upper_utf8(ctx_ptr, "dummy", bad_len, &out_len);
// Expect failure
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("Would overflow maximum output size"));
ctx.Reset();

// Negative length test
out_len = -1;
out = gdv_fn_upper_utf8(ctx_ptr, "abc", -105, &out_len);
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
ctx.Reset();

std::string e(
Expand Down Expand Up @@ -697,6 +721,26 @@ TEST(TestGdvFnStubs, TestLower) {
out_str = gdv_fn_lower_utf8(ctx_ptr, "", 0, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "");
EXPECT_FALSE(ctx.has_error());
ctx.Reset();

// Max Len Test
out_len = -1;
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
const char* out = gdv_fn_lower_utf8(ctx_ptr, "dummy", bad_len, &out_len);
// Expect failure
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("Would overflow maximum output size"));
ctx.Reset();

// Negative length test
out_len = -1;
out = gdv_fn_lower_utf8(ctx_ptr, "abc", -105, &out_len);
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
ctx.Reset();

std::string d("AbOJjÜoß\xc3");
out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len);
Expand Down Expand Up @@ -796,6 +840,25 @@ TEST(TestGdvFnStubs, TestInitCap) {
"unexpected byte \\c3 encountered while decoding utf8 string"));
ctx.Reset();

// Max Len Test
out_len = -1;
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
const char* out = gdv_fn_initcap_utf8(ctx_ptr, "dummy", bad_len, &out_len);
// Expect failure
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("Would overflow maximum output size"));
ctx.Reset();

// Negative length test
out_len = -1;
out = gdv_fn_initcap_utf8(ctx_ptr, "abc", -105, &out_len);
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(out, "");
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
ctx.Reset();

std::string e(
"åbÑg\xe0\xa0"
"åBUå");
Expand Down Expand Up @@ -1127,6 +1190,15 @@ TEST(TestGdvFnStubs, TestTranslate) {
result = translate_utf8_utf8_utf8(ctx_ptr, "987654321", 9, "123456789", 9, "0123456789",
10, &out_len);
EXPECT_EQ(expected, std::string(result, out_len));

int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 4 + 1;
out_len = -1;
result =
translate_utf8_utf8_utf8(ctx_ptr, "ABCDE", bad_in_len, "B", 1, "C", 1, &out_len);
EXPECT_EQ(out_len, 0);
EXPECT_STREQ(result, "");
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("Would overflow maximum output size"));
}

TEST(TestGdvFnStubs, TestToUtcTimezone) {
Expand Down
101 changes: 78 additions & 23 deletions cpp/src/gandiva/gdv_string_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,25 @@ int32_t gdv_fn_utf8_char_length(char c) {
return 0;
}

static inline bool is_datalen_valid(int64_t context, int32_t data_len, int32_t* alloc_len,
int32_t* out_len) {
// Reject negative lengths
if (ARROW_PREDICT_FALSE(data_len < 0)) {
gdv_fn_context_set_error_msg(context, "Invalid (negative) data length");
*out_len = 0;
return false;
}

// Check overflow: 2 * data_len
if (ARROW_PREDICT_FALSE(
arrow::internal::MultiplyWithOverflow(2, data_len, alloc_len))) {
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
*out_len = 0;
return false;
}
return true;
}

// Convert an utf8 string to its corresponding lowercase string
GANDIVA_EXPORT
const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len,
Expand All @@ -222,10 +241,15 @@ const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_le
return "";
}

int32_t alloc_length = 0;
if (ARROW_PREDICT_FALSE(!is_datalen_valid(context, data_len, &alloc_length, out_len))) {
return "";
}

// If it is a single-byte character (ASCII), corresponding lowercase is always 1-byte
// long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so length of
// the output can be at most twice the length of the input
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
*out_len = 0;
Expand Down Expand Up @@ -294,10 +318,15 @@ const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_le
return "";
}

int32_t alloc_length = 0;
if (ARROW_PREDICT_FALSE(!is_datalen_valid(context, data_len, &alloc_length, out_len))) {
return "";
}

// If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
// the output can be at most twice the length of the input
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
*out_len = 0;
Expand Down Expand Up @@ -367,6 +396,15 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
return "";
}

if (ARROW_PREDICT_FALSE(txt_len < 0)) {
*out_len = 0;
return "";
}
if (ARROW_PREDICT_FALSE(pat_len < 0)) {
*out_len = 0;
return "";
}

char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, txt_len));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
Expand Down Expand Up @@ -445,8 +483,8 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
return out;

} else {
memcpy(out, txt, static_cast<size_t>(txt_len));
*out_len = txt_len;
memcpy(out, txt, txt_len);
return out;
}
}
Expand Down Expand Up @@ -480,10 +518,16 @@ const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_
return "";
}

int32_t alloc_length = 0;
if (ARROW_PREDICT_FALSE(
!is_datalen_valid(context, data_len, &alloc_length, out_len))) {
return "";
}

// If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
// the output can be at most twice the length of the input
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
if (out == nullptr) {
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
*out_len = 0;
Expand Down Expand Up @@ -579,15 +623,24 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
return in;
}

int32_t alloc_length = 0;
// Check overflow: 4 * in_len
if (ARROW_PREDICT_FALSE(
arrow::internal::MultiplyWithOverflow(4, in_len, &alloc_length))) {
Comment on lines +626 to +629
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
*out_len = 0;
return "";
}

// This variable is to control if there are multi-byte utf8 entries
bool has_multi_byte = false;

// This variable is to store the final result
char* result;
int result_len;
int32_t result_len;

// Searching multi-bytes in In
for (int i = 0; i < in_len; i++) {
for (int32_t i = 0; i < in_len; i++) {
unsigned char char_single_byte = in[i];
if (char_single_byte > 127) {
// found a multi-byte utf-8 char
Expand All @@ -598,7 +651,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in

// Searching multi-bytes in From
if (!has_multi_byte) {
for (int i = 0; i < from_len; i++) {
for (int32_t i = 0; i < from_len; i++) {
unsigned char char_single_byte = from[i];
if (char_single_byte > 127) {
// found a multi-byte utf-8 char
Expand All @@ -610,7 +663,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in

// Searching multi-bytes in To
if (!has_multi_byte) {
for (int i = 0; i < to_len; i++) {
for (int32_t i = 0; i < to_len; i++) {
unsigned char char_single_byte = to[i];
if (char_single_byte > 127) {
// found a multi-byte utf-8 char
Expand Down Expand Up @@ -638,7 +691,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in

// This variable is for controlling the position in entry TO, for never repeat the
// changes
int start_compare;
int32_t start_compare;

if (to_len > 0) {
start_compare = 0;
Expand All @@ -650,15 +703,15 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
// list, to mark deletion positions
const char empty = '\0';

for (int in_for = 0; in_for < in_len; in_for++) {
for (int32_t in_for = 0; in_for < in_len; in_for++) {
if (subs_list.find(in[in_for]) != subs_list.end()) {
if (subs_list[in[in_for]] != empty) {
// If exist in map, only add the correspondent value in result
result[result_len] = subs_list[in[in_for]];
result_len++;
}
} else {
for (int from_for = 0; from_for <= from_len; from_for++) {
for (int32_t from_for = 0; from_for <= from_len; from_for++) {
if (from_for == from_len) {
// If it's not in the FROM list, just add it to the map and the result.
subs_list.insert(std::pair<char, char>(in[in_for], in[in_for]));
Expand Down Expand Up @@ -686,10 +739,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
}
}
}
} else { // If there are no multibytes in the input, work with std::strings
} else {
// If there are multibytes in the input, work with std::strings
// This variable is for receive the substitutions, malloc is in_len * 4 to receive
// possible inputs with 4 bytes
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, in_len * 4));
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));

if (result == nullptr) {
gdv_fn_context_set_error_msg(context,
Expand All @@ -704,7 +758,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in

// This variable is for controlling the position in entry TO, for never repeat the
// changes
int start_compare;
int32_t start_compare;

if (to_len > 0) {
start_compare = 0;
Expand All @@ -717,11 +771,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
const std::string empty = "";

// This variables is to control len of multi-bytes entries
int len_char_in = 0;
int len_char_from = 0;
int len_char_to = 0;
int32_t len_char_in = 0;
int32_t len_char_from = 0;
int32_t len_char_to = 0;

for (int in_for = 0; in_for < in_len; in_for += len_char_in) {
for (int32_t in_for = 0; in_for < in_len; in_for += len_char_in) {
// Updating len to char in this position
len_char_in = gdv_fn_utf8_char_length(in[in_for]);
// Making copy to std::string with length for this char position
Expand All @@ -734,11 +788,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
result_len += static_cast<int>(subs_list[insert_copy_key].length());
}
} else {
for (int from_for = 0; from_for <= from_len; from_for += len_char_from) {
// Updating len to char in this position
len_char_from = gdv_fn_utf8_char_length(from[from_for]);
// Making copy to std::string with length for this char position
std::string copy_from_compare(from + from_for, len_char_from);
for (int32_t from_for = 0; from_for <= from_len; from_for += len_char_from) {
if (from_for == from_len) {
// If it's not in the FROM list, just add it to the map and the result.
std::string insert_copy_value(in + in_for, len_char_in);
Expand All @@ -751,6 +801,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
break;
}

// Updating len to char in this position
len_char_from = gdv_fn_utf8_char_length(from[from_for]);
// Making copy to std::string with length for this char position
std::string copy_from_compare(from + from_for, len_char_from);

if (insert_copy_key != copy_from_compare) {
// If this character does not exist in FROM list, don't need treatment
continue;
Expand Down
Loading
Loading