From f16893403ffab57f9ccaa76daee9721a633fe3c9 Mon Sep 17 00:00:00 2001 From: Raiki Tamura Date: Sat, 11 May 2024 19:43:04 +0900 Subject: [PATCH] Export metadata type information Signed-off-by: Raiki Tamura --- src/lib.rs | 2 +- src/query.rs | 18 ++++++++++++++---- src/row.rs | 41 +++++++++++++++++++++++++++++++++++------ tests/test-chunk.rs | 18 +++++++++++++++++- 4 files changed, 67 insertions(+), 12 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 30dd71b..614042a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,7 +44,7 @@ mod session; use std::time::Duration; pub use error::{Error, Result}; -pub use row::{SnowflakeDecode, SnowflakeRow}; +pub use row::{SnowflakeColumn, SnowflakeColumnType, SnowflakeDecode, SnowflakeRow}; pub use session::SnowflakeSession; use auth::login; diff --git a/src/query.rs b/src/query.rs index b0ff62d..1768dd5 100644 --- a/src/query.rs +++ b/src/query.rs @@ -8,6 +8,7 @@ use http::{ use reqwest::Client; use tokio::time::sleep; +use crate::row::SnowflakeColumnType; use crate::{chunk::download_chunk, Error, Result, SnowflakeRow}; pub(super) const SESSION_EXPIRED: &str = "390112"; @@ -102,17 +103,26 @@ pub(super) async fn query>( row_set.extend(rows); } - let column_names = row_types + let column_types = row_types .into_iter() .enumerate() - .map(|(i, name)| (name.name.to_ascii_uppercase(), i)) + .map(|(i, row_type)| { + ( + row_type.name.to_ascii_uppercase(), + SnowflakeColumnType { + index: i, + snowflake_type: row_type.data_type, + nullable: row_type.nullable, + }, + ) + }) .collect::>(); - let column_names = Arc::new(column_names); + let column_types = Arc::new(column_types); Ok(row_set .into_iter() .map(|row| SnowflakeRow { row, - column_names: Arc::clone(&column_names), + column_types: Arc::clone(&column_types), }) .collect()) } diff --git a/src/row.rs b/src/row.rs index 87e7b52..bc88fcc 100644 --- a/src/row.rs +++ b/src/row.rs @@ -1,25 +1,54 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, ops::Deref, sync::Arc}; use chrono::{DateTime, Days, NaiveDate, NaiveDateTime}; use crate::{Error, Result}; +#[derive(Debug)] +pub struct SnowflakeColumn { + pub name: String, + pub column_type: SnowflakeColumnType, +} + +#[derive(Debug, Clone)] +pub struct SnowflakeColumnType { + /// The index of the column in the row + pub index: usize, + /// Data type of the column in Snowflake + pub snowflake_type: String, + /// Whether the column is nullable + pub nullable: bool, +} + #[derive(Debug)] pub struct SnowflakeRow { pub(crate) row: Vec>, - pub(crate) column_names: Arc>, + pub(crate) column_types: Arc>, } impl SnowflakeRow { pub fn get(&self, column_name: &str) -> Result { - let index = self - .column_names + let column_type = self + .column_types .get(&column_name.to_ascii_uppercase()) .ok_or_else(|| Error::Decode(format!("column not found: {}", column_name)))?; - self.row[*index].try_get() + self.row[column_type.index].try_get() } pub fn column_names(&self) -> Vec<&str> { - self.column_names.iter().map(|(k, _)| k.as_str()).collect() + self.column_types.iter().map(|(k, _)| k.as_str()).collect() + } + pub fn column_types(&self) -> Vec { + let column_types = self.column_types.deref(); + let mut v: Vec<_> = column_types + .iter() + .map(|(k, v)| SnowflakeColumn { + name: k.clone(), + column_type: v.clone(), + }) + .collect(); + // sort by column index + v.sort_by_key(|c| c.column_type.index); + v } } diff --git a/tests/test-chunk.rs b/tests/test-chunk.rs index 7e77147..14d962d 100644 --- a/tests/test-chunk.rs +++ b/tests/test-chunk.rs @@ -1,4 +1,6 @@ -use snowflake_connector_rs::{Result, SnowflakeAuthMethod, SnowflakeClient, SnowflakeClientConfig}; +use snowflake_connector_rs::{ + Result, SnowflakeAuthMethod, SnowflakeClient, SnowflakeClientConfig, SnowflakeColumnType, +}; #[tokio::test] async fn test_download_chunked_results() -> Result<()> { @@ -38,5 +40,19 @@ async fn test_download_chunked_results() -> Result<()> { assert!(rows[0].column_names().contains(&"SEQ")); assert!(rows[0].column_names().contains(&"RAND")); + let columns = rows[0].column_types(); + assert_eq!( + columns[0].column_type.snowflake_type.to_ascii_uppercase(), + "FIXED" + ); + assert_eq!(columns[0].column_type.nullable, false); + assert_eq!(columns[0].column_type.index, 0); + assert_eq!( + columns[1].column_type.snowflake_type.to_ascii_uppercase(), + "TEXT" + ); + assert_eq!(columns[1].column_type.nullable, false); + assert_eq!(columns[1].column_type.index, 1); + Ok(()) }