Skip to content

Commit

Permalink
Export metadata type information
Browse files Browse the repository at this point in the history
Signed-off-by: Raiki Tamura <[email protected]>
  • Loading branch information
tamaroning committed May 17, 2024
1 parent 24a8cae commit f168934
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ mod session;
use std::time::Duration;

pub use error::{Error, Result};
pub use row::{SnowflakeDecode, SnowflakeRow};
pub use row::{SnowflakeColumn, SnowflakeColumnType, SnowflakeDecode, SnowflakeRow};
pub use session::SnowflakeSession;

use auth::login;
Expand Down
18 changes: 14 additions & 4 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use http::{
use reqwest::Client;
use tokio::time::sleep;

use crate::row::SnowflakeColumnType;
use crate::{chunk::download_chunk, Error, Result, SnowflakeRow};

pub(super) const SESSION_EXPIRED: &str = "390112";
Expand Down Expand Up @@ -102,17 +103,26 @@ pub(super) async fn query<Q: Into<QueryRequest>>(
row_set.extend(rows);
}

let column_names = row_types
let column_types = row_types
.into_iter()
.enumerate()
.map(|(i, name)| (name.name.to_ascii_uppercase(), i))
.map(|(i, row_type)| {
(
row_type.name.to_ascii_uppercase(),
SnowflakeColumnType {
index: i,
snowflake_type: row_type.data_type,
nullable: row_type.nullable,
},
)
})
.collect::<HashMap<_, _>>();
let column_names = Arc::new(column_names);
let column_types = Arc::new(column_types);
Ok(row_set
.into_iter()
.map(|row| SnowflakeRow {
row,
column_names: Arc::clone(&column_names),
column_types: Arc::clone(&column_types),
})
.collect())
}
Expand Down
41 changes: 35 additions & 6 deletions src/row.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,54 @@
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, ops::Deref, sync::Arc};

use chrono::{DateTime, Days, NaiveDate, NaiveDateTime};

use crate::{Error, Result};

#[derive(Debug)]
pub struct SnowflakeColumn {
pub name: String,
pub column_type: SnowflakeColumnType,
}

#[derive(Debug, Clone)]
pub struct SnowflakeColumnType {
/// The index of the column in the row
pub index: usize,
/// Data type of the column in Snowflake
pub snowflake_type: String,
/// Whether the column is nullable
pub nullable: bool,
}

#[derive(Debug)]
pub struct SnowflakeRow {
pub(crate) row: Vec<Option<String>>,
pub(crate) column_names: Arc<HashMap<String, usize>>,
pub(crate) column_types: Arc<HashMap<String, SnowflakeColumnType>>,
}

impl SnowflakeRow {
pub fn get<T: SnowflakeDecode>(&self, column_name: &str) -> Result<T> {
let index = self
.column_names
let column_type = self
.column_types
.get(&column_name.to_ascii_uppercase())
.ok_or_else(|| Error::Decode(format!("column not found: {}", column_name)))?;
self.row[*index].try_get()
self.row[column_type.index].try_get()
}
pub fn column_names(&self) -> Vec<&str> {
self.column_names.iter().map(|(k, _)| k.as_str()).collect()
self.column_types.iter().map(|(k, _)| k.as_str()).collect()
}
pub fn column_types(&self) -> Vec<SnowflakeColumn> {
let column_types = self.column_types.deref();
let mut v: Vec<_> = column_types
.iter()
.map(|(k, v)| SnowflakeColumn {
name: k.clone(),
column_type: v.clone(),
})
.collect();
// sort by column index
v.sort_by_key(|c| c.column_type.index);
v
}
}

Expand Down
18 changes: 17 additions & 1 deletion tests/test-chunk.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use snowflake_connector_rs::{Result, SnowflakeAuthMethod, SnowflakeClient, SnowflakeClientConfig};
use snowflake_connector_rs::{
Result, SnowflakeAuthMethod, SnowflakeClient, SnowflakeClientConfig, SnowflakeColumnType,
};

#[tokio::test]
async fn test_download_chunked_results() -> Result<()> {
Expand Down Expand Up @@ -38,5 +40,19 @@ async fn test_download_chunked_results() -> Result<()> {
assert!(rows[0].column_names().contains(&"SEQ"));
assert!(rows[0].column_names().contains(&"RAND"));

let columns = rows[0].column_types();
assert_eq!(
columns[0].column_type.snowflake_type.to_ascii_uppercase(),
"FIXED"
);
assert_eq!(columns[0].column_type.nullable, false);
assert_eq!(columns[0].column_type.index, 0);
assert_eq!(
columns[1].column_type.snowflake_type.to_ascii_uppercase(),
"TEXT"
);
assert_eq!(columns[1].column_type.nullable, false);
assert_eq!(columns[1].column_type.index, 1);

Ok(())
}

0 comments on commit f168934

Please sign in to comment.