diff --git a/.gitignore b/.gitignore index 96ef6c0..04fa3b6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /target Cargo.lock +.idea/ + diff --git a/convergence-arrow/src/table.rs b/convergence-arrow/src/table.rs index fb23e98..34f9740 100644 --- a/convergence-arrow/src/table.rs +++ b/convergence-arrow/src/table.rs @@ -4,7 +4,7 @@ use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlSta use convergence::protocol_ext::DataRowBatch; use datafusion::arrow::array::{ BooleanArray, Date32Array, Date64Array, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + Int64Array, Int8Array, StringArray, StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit}; @@ -48,6 +48,7 @@ pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBat DataType::Float32 => row.write_float4(array_val!(Float32Array, col, row_idx)), DataType::Float64 => row.write_float8(array_val!(Float64Array, col, row_idx)), DataType::Utf8 => row.write_string(array_val!(StringArray, col, row_idx)), + DataType::Utf8View => row.write_string(array_val!(StringViewArray, col, row_idx)), DataType::Date32 => { row.write_date(array_val!(Date32Array, col, row_idx, value_as_date).ok_or_else(|| { ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type") @@ -102,7 +103,7 @@ pub fn data_type_to_oid(ty: &DataType) -> Result { DataType::UInt64 => DataTypeOid::Int8, DataType::Float16 | DataType::Float32 => DataTypeOid::Float4, DataType::Float64 => DataTypeOid::Float8, - DataType::Utf8 => DataTypeOid::Text, + DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text, DataType::Date32 | DataType::Date64 => DataTypeOid::Date, DataType::Timestamp(_, None) => DataTypeOid::Timestamp, other => { diff --git a/convergence-arrow/tests/test_arrow.rs b/convergence-arrow/tests/test_arrow.rs index f1dc31e..3a9e8e8 100644 --- a/convergence-arrow/tests/test_arrow.rs +++ b/convergence-arrow/tests/test_arrow.rs @@ -6,7 +6,7 @@ use convergence::protocol_ext::DataRowBatch; use convergence::server::{self, BindOptions}; use convergence::sqlparser::ast::Statement; use convergence_arrow::table::{record_batch_to_rows, schema_to_field_desc}; -use datafusion::arrow::array::{ArrayRef, Date32Array, Float32Array, Int32Array, StringArray, TimestampSecondArray}; +use datafusion::arrow::array::{ArrayRef, Date32Array, Float32Array, Int32Array, StringArray, StringViewArray, TimestampSecondArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::arrow::record_batch::RecordBatch; use std::sync::Arc; @@ -32,6 +32,7 @@ impl ArrowEngine { let int_col = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; let float_col = Arc::new(Float32Array::from(vec![1.5, 2.5, 3.5])) as ArrayRef; let string_col = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef; + let string_view_col = Arc::new(StringViewArray::from(vec!["aa", "bb", "cc"])) as ArrayRef; let ts_col = Arc::new(TimestampSecondArray::from(vec![1577836800, 1580515200, 1583020800])) as ArrayRef; let date_col = Arc::new(Date32Array::from(vec![0, 1, 2])) as ArrayRef; @@ -39,12 +40,13 @@ impl ArrowEngine { Field::new("int_col", DataType::Int32, true), Field::new("float_col", DataType::Float32, true), Field::new("string_col", DataType::Utf8, true), + Field::new("string_view_col", DataType::Utf8View, true), Field::new("ts_col", DataType::Timestamp(TimeUnit::Second, None), true), Field::new("date_col", DataType::Date32, true), ]); Self { - batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, string_col, ts_col, date_col]) + batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, string_col, string_view_col, ts_col, date_col]) .expect("failed to create batch"), } } @@ -89,8 +91,8 @@ async fn basic_data_types() { let rows = client.query("select 1", &[]).await.unwrap(); let get_row = |idx: usize| { let row = &rows[idx]; - let cols: (i32, f32, &str, NaiveDateTime, NaiveDate) = - (row.get(0), row.get(1), row.get(2), row.get(3), row.get(4)); + let cols: (i32, f32, &str, &str, NaiveDateTime, NaiveDate) = + (row.get(0), row.get(1), row.get(2), row.get(3), row.get(4), row.get(5)); cols }; @@ -100,6 +102,7 @@ async fn basic_data_types() { 1, 1.5, "a", + "aa", NaiveDate::from_ymd_opt(2020, 1, 1) .unwrap() .and_hms_opt(0, 0, 0) @@ -113,6 +116,7 @@ async fn basic_data_types() { 2, 2.5, "b", + "bb", NaiveDate::from_ymd_opt(2020, 2, 1) .unwrap() .and_hms_opt(0, 0, 0) @@ -126,6 +130,7 @@ async fn basic_data_types() { 3, 3.5, "c", + "cc", NaiveDate::from_ymd_opt(2020, 3, 1) .unwrap() .and_hms_opt(0, 0, 0)