diff --git a/src/executor/execute.rs b/src/executor/execute.rs index 5a493457..842f15e2 100644 --- a/src/executor/execute.rs +++ b/src/executor/execute.rs @@ -1,8 +1,8 @@ use { - crate::{parse_sql::Query, Glue, Result, Row, Value}, + super::types::get_first_name, + crate::{parse_sql::Query, Glue, Result, Row}, serde::Serialize, - sqlparser::ast::{SetVariableValue, Statement}, - std::convert::TryInto, + sqlparser::ast::{ObjectType, Statement}, thiserror::Error as ThisError, }; @@ -57,6 +57,24 @@ impl Glue { let Query(statement) = statement; match statement { + Statement::CreateDatabase { + db_name, + if_not_exists, + location, + .. + } => { + if !self.try_extend_from_path( + db_name.0[0].value.clone(), + location + .clone() + .ok_or(ExecuteError::InvalidDatabaseLocation)?, + )? && !if_not_exists + { + Err(ExecuteError::DatabaseExists(db_name.0[0].value.clone()).into()) + } else { + Ok(Payload::Success) + } + } //- Modification //-- Tables Statement::CreateTable { @@ -73,10 +91,20 @@ impl Glue { names, if_exists, .. - } => self - .drop(object_type, names, *if_exists) - .await - .map(|_| Payload::DropTable), + } => match object_type { + ObjectType::Schema => { + // Schema for now // TODO: sqlparser-rs#454 + if !self.reduce(&get_first_name(names)?) && !if_exists { + Err(ExecuteError::ObjectNotRecognised.into()) + } else { + Ok(Payload::Success) + } + } + object_type => self + .drop(object_type, names, *if_exists) + .await + .map(|_| Payload::DropTable), + }, #[cfg(feature = "alter-table")] Statement::AlterTable { name, operation } => self .alter_table(name, operation) @@ -128,20 +156,14 @@ impl Glue { //- Context Statement::SetVariable { variable, value, .. - } => { - let first_value = value.get(0).unwrap(); // Why might one want anything else? - let value: Value = match first_value { - SetVariableValue::Ident(..) => unimplemented!(), - SetVariableValue::Literal(literal) => literal.try_into()?, - }; - let name = variable.value.clone(); - self.get_mut_context()?.set_variable(name, value); - Ok(Payload::Success) - } + } => self + .set_variable(variable, value) + .await + .map(|_| Payload::Success), Statement::ExplainTable { table_name, .. } => self.explain(table_name).await, - Statement::CreateDatabase { .. } => unreachable!(), // Handled at Glue interface // TODO: Clean up somehow + Statement::Execute { name, parameters } => self.procedure(name, parameters).await, _ => Err(ExecuteError::QueryNotSupported.into()), } } diff --git a/src/executor/mod.rs b/src/executor/mod.rs index d3cf5531..deada527 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -3,8 +3,10 @@ mod alter_table; mod execute; mod fetch; mod other; +mod procedure; mod query; mod recipe; +mod set_variable; mod types; pub use { diff --git a/src/executor/procedure.rs b/src/executor/procedure.rs new file mode 100644 index 00000000..d7129666 --- /dev/null +++ b/src/executor/procedure.rs @@ -0,0 +1,25 @@ +use { + crate::{ExecuteError, Glue, Payload, Result}, + sqlparser::ast::{Expr, Ident, Value as AstValue}, +}; + +impl Glue { + pub async fn procedure(&mut self, name: &Ident, parameters: &[Expr]) -> Result { + return match name.value.as_str() { + "FILE" => { + if let Some(Ok(query)) = parameters.get(0).map(|path| { + if let Expr::Value(AstValue::SingleQuotedString(path)) = path { + std::fs::read_to_string(path).map_err(|_| ()) + } else { + Err(()) + } + }) { + self.execute(&query) + } else { + Err(ExecuteError::InvalidFileLocation.into()) + } + } + _ => Err(ExecuteError::Unimplemented.into()), + }; + } +} diff --git a/src/executor/set_variable.rs b/src/executor/set_variable.rs new file mode 100644 index 00000000..b852e16c --- /dev/null +++ b/src/executor/set_variable.rs @@ -0,0 +1,21 @@ +use { + crate::{ExecuteError, Glue, Result, Value}, + sqlparser::ast::{Ident, SetVariableValue}, +}; + +impl Glue { + pub async fn set_variable( + &mut self, + variable: &Ident, + value: &[SetVariableValue], + ) -> Result<()> { + let first_value = value.get(0).ok_or(ExecuteError::MissingComponentsForSet)?; + let value: Value = match first_value { + SetVariableValue::Ident(..) => unimplemented!(), + SetVariableValue::Literal(literal) => literal.try_into()?, + }; + let name = variable.value.clone(); + self.get_mut_context()?.set_variable(name, value); + Ok(()) + } +} diff --git a/src/executor/types.rs b/src/executor/types.rs index 106c3b6b..c4e9e794 100644 --- a/src/executor/types.rs +++ b/src/executor/types.rs @@ -1,5 +1,5 @@ use { - crate::{JoinError, Result, Value}, + crate::{ExecuteError, JoinError, Result, Value}, serde::Serialize, sqlparser::ast::{ObjectName as AstObjectName, TableFactor}, std::fmt::Debug, @@ -18,6 +18,13 @@ pub struct ColumnInfo { pub index: Option, } +pub(crate) fn get_first_name(names: &[AstObjectName]) -> Result { + names + .get(0) + .and_then(|name| name.0.get(0).map(|name| name.value.clone())) + .ok_or(ExecuteError::ObjectNotRecognised.into()) +} + #[derive(Debug, Clone, PartialEq, Serialize)] pub struct ComplexTableName { pub database: Option, diff --git a/src/glue/mod.rs b/src/glue/mod.rs index f175514c..d7c6e39c 100644 --- a/src/glue/mod.rs +++ b/src/glue/mod.rs @@ -8,8 +8,7 @@ use { }, futures::executor::block_on, sqlparser::ast::{ - Expr, Ident, ObjectName, ObjectType, Query as AstQuery, SetExpr, Statement, - Value as AstValue, Values, + Expr, Ident, ObjectName, Query as AstQuery, SetExpr, Statement, Value as AstValue, Values, }, std::{collections::HashMap, fmt::Debug}, }; @@ -104,10 +103,10 @@ impl Glue { /// .expect("Storage Creation Failed"); /// let mut other_glue = Glue::new(String::from("other"), other_storage); /// - /// glue.extend(vec![other_glue]); + /// glue.extend_many_glues(vec![other_glue]); /// ``` /// - pub fn extend(&mut self, glues: Vec) { + pub fn extend_many_glues(&mut self, glues: Vec) { self.databases.extend( glues .into_iter() @@ -119,6 +118,53 @@ impl Glue { .databases, ) } + pub fn extend_glue(&mut self, glue: Glue) { + self.databases.extend(glue.databases) + } + + /// Extend using a ~~[Path]~~ [String] which represents a path + /// Guesses the type of database based on the extension + /// Returns [bool] of whether action was taken + pub fn try_extend_from_path( + &mut self, + database_name: String, + database_path: String, + ) -> Result { + if self.databases.contains_key(&database_name) { + return Ok(false); + } + let connection = if database_path.ends_with('/') { + Connection::Sled(database_path) + } else if database_path.ends_with(".csv") { + Connection::CSV(database_path, CSVSettings::default()) + } else if database_path.ends_with(".xlsx") { + Connection::Sheet(database_path) + } else { + return Err(ExecuteError::InvalidDatabaseLocation.into()); + }; + let database = connection.try_into()?; + Ok(self.extend(database_name, database)) + } + + /// Extend [Glue] by single database + /// Returns [bool] of whether action was taken + pub fn extend(&mut self, database_name: String, database: Storage) -> bool { + let database_present = self.databases.contains_key(&database_name); + if !database_present { + self.databases.insert(database_name, database); + } + !database_present + } + + /// Opposite of [Glue::extend], removes database + /// Returns [bool] of whether action was taken + pub fn reduce(&mut self, database_name: &String) -> bool { + let database_present = self.databases.contains_key(database_name); + if database_present { + self.databases.remove(database_name); + } + database_present + } } /// Internal: Modify @@ -165,75 +211,6 @@ impl Glue { } /// Will execute a pre-parsed query (see [Glue::pre_parse()] for more). pub fn execute_parsed(&mut self, query: Query) -> Result { - if let Query(Statement::CreateDatabase { - db_name, - if_not_exists, - location, - .. - }) = query - { - let store_name = db_name.0[0].value.clone(); - return if self.databases.iter().any(|(store, _)| store == &store_name) { - if if_not_exists { - Ok(Payload::Success) - } else { - Err(ExecuteError::DatabaseExists(store_name).into()) - } - } else { - match location { - None => Err(ExecuteError::InvalidDatabaseLocation.into()), // TODO: Memory - Some(location) => { - let store = if location.ends_with('/') { - Connection::Sled(location).try_into()? - } else if location.ends_with(".csv") { - Connection::CSV(location, CSVSettings::default()).try_into()? - } else if location.ends_with(".xlsx") { - Connection::Sheet(location).try_into()? - } else { - return Err(ExecuteError::InvalidDatabaseLocation.into()); - }; - self.extend(vec![Glue::new(store_name, store)]); - Ok(Payload::Success) - } - } - }; - } else if let Query(Statement::Execute { name, parameters }) = query { - return match name.value.as_str() { - "FILE" => { - if let Some(Ok(query)) = parameters.get(0).map(|path| { - if let Expr::Value(AstValue::SingleQuotedString(path)) = path { - std::fs::read_to_string(path).map_err(|_| ()) - } else { - Err(()) - } - }) { - self.execute(&query) - } else { - Err(ExecuteError::InvalidFileLocation.into()) - } - } - _ => Err(ExecuteError::Unimplemented.into()), - }; - } else if let Query(Statement::Drop { - object_type: ObjectType::Schema, // FOR NOW! // TODO: sqlparser-rs#454 - if_exists, - names, - .. - }) = query - { - let database_name = names - .get(0) - .and_then(|name| name.0.get(0).map(|name| name.value.clone())) - .ok_or(ExecuteError::ObjectNotRecognised)?; - - if self.databases.contains_key(&database_name) { - self.databases.remove(&database_name); - } else if !if_exists { - return Err(ExecuteError::ObjectNotRecognised.into()); - } - return Ok(Payload::Success); - } - block_on(self.execute_query(&query)) } /// Provides a parsed query to execute later. diff --git a/src/storages/sheet_storage/store.rs b/src/storages/sheet_storage/store.rs index d03a5182..bf174d01 100644 --- a/src/storages/sheet_storage/store.rs +++ b/src/storages/sheet_storage/store.rs @@ -32,7 +32,7 @@ impl Store for SheetStorage { let rows = vec![vec![None; col_count as usize]; (row_count as usize) - 1]; let rows = sheet .get_collection_to_hashmap() - .into_iter() + .iter() .filter(|((row, _col), _)| row != &1) .fold(rows, |mut rows, ((row_num, col_num), cell)| { rows[(row_num - 2) as usize][(col_num - 1) as usize] = Some(cell.clone()); @@ -53,7 +53,7 @@ impl Store for SheetStorage { cell.map(|cell| cell.get_value().to_string()) .unwrap_or_default(), ) - .cast_valuetype(&data_type) + .cast_valuetype(data_type) .unwrap_or(Value::Null) }) .collect()),