Initial commit

Dan Spencer 2015-03-10 08:07:47 -06:00
commit 0e112cfc1c
22 changed files with 3009 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
target
Cargo.lock
*.db

5
Cargo.toml Normal file
View File

@ -0,0 +1,5 @@
[package]
name = "llamadb"
version = "0.0.1"
authors = ["Dan Spencer <dan@atomicpotato.net>"]

84
README.md Normal file
View File

@ -0,0 +1,84 @@
# LlamaDB
**Warning**: This project is in the design/implementation phase, and is not
functional. Do NOT use this for anything you depend on!
LlamaDB is a versatile, speedy and low-footprint SQL database, written entirely
in the Rust programming language.
# Another SQL database? Why?
The project is driven by two personal goals:
1. Understand SQL better: both the language and the implementation details
2. Write a mission-critical project with Rust
# Example
Note: DOESN'T WORK YET. Intended design.
```rust
let mut db = llamadb::MemoryDatabase::new();
sql!(db,
"CREATE TABLE account (
id UUID PRIMARY KEY,
username VARCHAR,
password BCRYPT
);"
).unwrap();
let row = sql!(db,
"INSERT INTO account(username, password) VALUES(?, ?);",
"John", bcrypt("secretpassword")
).unwrap();
println!("{}", row["id"].getString());
```
## Principles
1. Keep it simple, stupid!
* No triggers, no users, no embedded SQL languages (e.g. PL/SQL).
2. No temporary files
* Temporary files can introduce unexpected security holes if the user is unaware
of them. For instance, SQLite creates temporary files in the OS's /tmp
directory and alongside the SQLite database file.
* The user may expect all database files to be on an encrypted drive.
Sensitive data shouldn't leak to unexpected places.
## API design
When using the Rust or C API, all columns are sent and received as _byte arrays_.
It is up to the client-side driver to convert these types to numbers or strings if applicable.
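As a rough illustration (not the final API), a client-side driver might decode those byte arrays like this; the helper names are hypothetical:

```rust
// Hypothetical client-side helpers: LlamaDB itself would only hand back raw bytes.
fn decode_text(raw: &[u8]) -> Option<&str> {
    // Interpret the bytes as UTF-8 text.
    std::str::from_utf8(raw).ok()
}

fn decode_u32_le(raw: &[u8]) -> Option<u32> {
    // Interpret exactly 4 bytes as a little-endian unsigned integer.
    if raw.len() != 4 {
        return None;
    }
    let mut bytes = [0u8; 4];
    bytes.copy_from_slice(raw);
    Some(u32::from_le_bytes(bytes))
}
```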
## Security principles
1. The database should NEVER receive plain-text passwords.
* This is why LlamaDB does not include any hashing functions;
the inclusion of hashing functions would encourage the sending of plain-text
passwords over a database connection.
* If the DBAs and developers aren't careful, the passwords could be logged by
the DBMS through query logging.
* Hashing algorithms such as bcrypt (should) require a CRNG for the salt, and
embedding CRNGs is not in the scope of LlamaDB.
* Nowadays, it's easy to perform application-side hashing. There's a vast
supply of good crypto libraries for every important programming language.
## NULL
Unlike standard SQL, NULL is opt-in on table creation.
For users of other SQL databases, just think of all `CREATE TABLE` columns as
having an implicit `NOT NULL`.
Null still exists as a placeholder value and for outer joins; it's just not the
default for `CREATE TABLE` columns.
If NULL is desired for a column, add the `NULL` constraint.
## Special thanks
A **HUGE THANKS** goes out to SQLite and the SQLite documentation.
Their wonderful docs helped shed some light on the SQL syntax and other crucial
details such as their B-Tree implementation.

115
docs/B+Tree-Pseudocode.md Normal file
View File

@ -0,0 +1,115 @@
# Pseudocode
To be used as a reference for implementation.
## Key comparison
B+Tree key comparison is done in a way that avoids loading overflow pages.
```
# Compares a cell with a key. Returns <, =, or >
compare_cell(cell, key):
    v = cell.in_page_payload.truncate_to(key.length)

    # Compare the bytes of v and key
    return v.compare(key) if not Equal
    return Equal if cell.payload_length <= key.length

    # Sad path: we need to read overflow pages.
    for overflow_page in overflow pages:
        v = overflow_page.truncate_to(key.length - overflow_offset)
        return v.compare(key[overflow_offset..]) if not Equal

    # Every overflow page compared equal. The entire payload is equal.
    return Equal
```
## Key search
```
find_keys(root_page_id, min, max):
    (right_pointer, cells) = find_min_leaf_offset(root_page_id, min)
    loop:
        for cell in cells:
            if cell > max:
                return
            read full cell data
            yield cell data

        # End of leaf node. Keep traversing right.
        if right_pointer:
            page = read right pointer page
            right_pointer = parse page header.right_pointer
            cells = parse page cells
        else:
            return

find_min_leaf_offset(page_id, min):
    right_pointer = (parse page header).right_pointer
    cells = parse page cells
    if leaf:
        offset = offset where cell >= min
        if offset found:
            return (right_pointer, cells[offset..])
        else if right_pointer:
            # The offset is definitely the first cell on the right page.
            right_page_header = parse right pointer page header
            return (right_page_header.right_pointer, read right page cells)
        else:
            # No offset can be found
            return None
    else:
        for cell in cells:
            if cell >= min:
                return find_min_leaf_offset(cell.left_pointer, min)
        return find_min_leaf_offset(right_pointer, min)
```
## Key insertion
```
insert_key_to_root(root_page_id, key):
    if root page is leaf:
        insert_key_to_root_leaf(root_page_id, key)
    else:
        if let Some(insert_to_parent) = insert_key_to_nonroot(root_page_id, key):

insert_key_to_nonroot(page_id, key):
    if page is leaf:
        insert_key_to_nonroot_leaf(page_id, key)
    else:
        cells = parse page cells

insert_key_to_root_leaf(page_id, key):

insert_key_to_nonroot_leaf(page_id, key):
    cells = parse page cells
    insert key into cells
    if cells.length > MAX_CELLS:
        # Split the node
        # Get split offset. Uses truncating division.
        split_offset = cells.length / 2
        new_page_id = create new page
        copy cells[..split_offset] to page_id
        copy cells[split_offset..] to new_page_id
        new_page_id.right_pointer = page_id.right_pointer
        page_id.right_pointer = new_page_id
        # Tell the caller about the left and right page ids, and the split key
        return Some(page_id, new_page_id, cells[split_offset])
    else:
        copy cells to page_id
        return None
```

109
docs/B+Tree.md Normal file
View File

@ -0,0 +1,109 @@
# B+Tree
The only index that LlamaDB supports so far is the B+Tree index.
The B+Tree is a data structure optimized for ranged searches, insertions
and deletions.
The B+Tree was chosen instead of the B-Tree because it supports direct traversal
over leaf nodes when iterating a range of keys.
For example, performing a `SELECT * FROM table;` will not repeatedly climb up
and down interior pages, as would happen with B-Trees.
Each B+Tree has a property for Cell length.
Any key in excess of this length has its remainder put into overflow pages.
Each B+Tree cell has a fixed-length.
The intent of using fixed-length cells instead of variable-length cells is to
simplify searches, node splitting and insertions.
While it's true that fixed-length cells may lead to wasted space, the real-world
problems that arise from this are likely insignificant.
As a general rule of thumb, LlamaDB values performance and simplicity over
saving a few bytes of disk space.
## B+Tree page structure
### Header, 16 bytes
| Size and type | Name |
|---------------|---------------------------------------------|
| 1, u8-le | Flags |
| 3 | unused/reserved |
| 2, u16-le | Page cell count |
| 2, u16-le | Cell length |
| 8, u64-le | Right page pointer |
**Flags: 000000RL**
* R: Root page. 0 if non-root page, 1 if root page.
* L: Leaf page. 0 if interior page, 1 if leaf page.
For leaf nodes, the right page pointer is a reference to the next leaf node for traversal.
The referenced leaf node contains keys that are sorted after ones from the current leaf node.
All leaf nodes form a linked list that contain all keys in the B+Tree.
If the page is the last leaf page of the B+Tree, there is no right page pointer.
In this case, the right page pointer is set to a value of 0.
In all other cases, the right page pointer must be set to a valid page id.
Cell length is the fixed length of the Cell structure.
Cell length must be a minimum of 24 bytes: 20 byte header + minimum 4-byte in-page payload.
All child B+Tree pages must have the same cell length as the root B+Tree page.
This invariant is useful for node splitting: a cell can then simply be moved
byte-for-byte into a new page without worrying about incompatible cell lengths.
Let P = Page size. Let L = Cell length. Let C = Max page cell count.
* C = floor((P - 16) / L)
* L = floor((P - 16) / C)
* C has the minimum: 2
* C has the maximum: floor((P - 16) / 24)
* L has the minimum: 24
* L has the maximum: (P - 16) / 2
| Page size | Cell length | Max cell count per page |
|-----------|-------------|-------------------------|
| 65536 | 24 | 2730 |
| 65536 | 32760 | 2 |
| 4096 | 24 | 170 |
| 4096 | 2040 | 2 |
| 512 | 24 | 20 |
| 512 | 248 | 2 |
| 64 | 24 | 2 |
Note that 65536 is the largest allowed page size,
and 64 is the smallest allowed page size.
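As a quick sanity check of the formulas and table above, here's a minimal sketch (not part of the crate) of the cell-count arithmetic:

```rust
// C = floor((P - 16) / L), with L >= 24 and P a power of two in 64..=65536.
fn max_cell_count(page_size: usize, cell_length: usize) -> usize {
    assert!(page_size.is_power_of_two());
    assert!(page_size >= 64 && page_size <= 65536);
    assert!(cell_length >= 24); // 20-byte cell header + minimum 4-byte in-page payload
    (page_size - 16) / cell_length // 16 bytes are taken by the page header
}

fn main() {
    assert_eq!(max_cell_count(65536, 24), 2730);
    assert_eq!(max_cell_count(4096, 24), 170);
    assert_eq!(max_cell_count(512, 24), 20);
    assert_eq!(max_cell_count(64, 24), 2);
}
```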
### Cell
| Size and type | Name |
|-------------------|------------------------------------------|
| 8, u64-le | Left page pointer (ignored if leaf page) |
| 4, u32-le | Payload length |
| 8, u64-le | _Overflow page (omitted if not needed)_ |
| Remainder of cell | In-page payload |
The left page pointer is _ignored_ instead of _omitted_ for leaf pages.
This is to avoid issues in the event that a leaf page is converted to an
interior page.
Rationale:
If the left page pointer were omitted for leaf pages, the pointer would need to
be added back when the cell is converted for an interior page. The cell length
is always fixed, so in the event that the cell also has overflow data,
all of the overflow data _and all of its pages_ would need to be shifted
by 8 bytes.
The current solution doesn't need to read the overflow pages, which is better
for caching.
If the payload length is less than the remainder of the cell, the data is
padded with zeros.
## Insertion
TODO

94
docs/Column Types.md Normal file
View File

@ -0,0 +1,94 @@
# Column Types
## Primitive data types
LlamaDB's column types are designed to be orthogonal and non-ambiguous in their
use.
Primitive data types are the building blocks on which all LlamaDB column types
are built.
Primitive data types are meant to convey what is literally stored in the
database, not how said data is serialized or deserialized.
* `byte` - An octet, contains 8 bits of data.
* `T[]` - A variable-length array of types.
* The equivalent of `BLOB` from MySQL would be `byte[]`
* `T[N]` - A fixed-length array of types, where `N` is a constant, positive integer.
* The equivalent of `BINARY(N)` from MySQL would be `byte[N]`
## Abstract types
Abstract types are typed wrappers for `byte[]` or `byte[N]` primitives.
Abstract types address the following use cases:
* Input validation
* Serialization and deserialization for the SQL language
Types:
* `uX` - Unsigned integer
* Backing primitive: `byte[Y]`, where `Y` = ceiling(`X` / 8)
* `iX` - Signed integer
* Backing primitive: `byte[Y]`, where `Y` = ceiling(`X` / 8)
* Range is two's complement. e.g. (`i8` has a range of -128 to +127)
* `f32` - A floating point number
* Backing primitive: `byte[4]`
* `char` - A Unicode character; a code point
* Backing primitive: `byte[4]`
* `string` - A UTF-8 encoded string.
* Backing primitive: `byte[]`
* The specified length (if any) is the maximum character length, not the byte length.
* `json` - Serialized JSON, useful for document stores as seen in NoSQL databases.
* Backing primitive: `byte[]`
* Validated and serialized using MessagePack
* Canonical serialization, can be compared for equality
* `bool` - A boolean
* Backing primitive: `byte[1]`
* Stored as 1 for `true`, 0 for `false`
* `uuid` - A universally unique identifier.
* Backing primitive: `byte[16]`
* `bcrypt` - A Bcrypt hash
* Backing primitive: `byte[40]`
* Serialized using Bcrypt BMCF Definition: https://github.com/ademarre/binary-mcf#bcrypt-bmcf-definition
* 8-bit header, 128-bit salt, 184-bit digest (NOT 192-bit digest!)
* If a 192-bit digest is provided, the last 8 bits will be discarded.
This is due to a bug in the original bcrypt implementation that discards the
last 8 bits on stringification.
* The database simply stores bcrypt hashes; it cannot perform any hashing algorithms.
* `scrypt`
* `pbkdf2`
Any `byte[]` or `byte[N]` column can be converted to alternative representations:
```sql
SELECT bytes_column from MY_TABLE;
# {DE AD BE EF} (byte[])
SELECT bytes_column.hex from MY_TABLE;
# 'DEADBEEF' (varchar)
SELECT bytes_column.base64 from MY_TABLE;
# '3q2+7w==' (varchar)
INSERT INTO bytes_column VALUES ({DE AD C0 DE});
INSERT INTO bytes_column.hex VALUES ('DEADC0DE');
```
All abstract types have accessors to the backing primitive:
```sql
SELECT json_column from MY_TABLE;
# { "is_json": true } (json)
SELECT json_column.raw from MY_TABLE;
# {81 A7 69 73 5F 6A 73 6F 6E C3} (byte[])
SELECT json_column.raw.hex from MY_TABLE;
# '81A769735F6A736F6EC3' (varchar)
SELECT json_column.raw.base64 from MY_TABLE;
# 'gadpc19qc29uww==' (varchar)
```

124
docs/Indexing and Sorting.md Normal file
View File

@ -0,0 +1,124 @@
# Indexing and Sorting
This is perhaps the most important implementation problem that SQL databases
must address.
## Simple and ignorant
All sorting is done with simple `memcmp()` operations.
This means that all keys' byte representations sort the same way as the keys
do semantically.
The B+Tree traversal algorithm is kept simpler this way.
The algorithm doesn't need to be aware of the types contained in the keys, so
there's no need for specialized comparators.
To the traversal algorithm, all keys are simple byte collections that are always
ordered the same way.
## Byte sorting
All keys are stored and sorted as a collection of bytes.
Here's a sorted byte list:
```
00
00 00
00 00 FF
00 01
01
02 00
...
FE FF FF FF FF FF FF
FF
FF 00
FF FF
FF FF FF
FF FF FF FF
```
Keys that share the same beginning as another key but are longer are sorted after.
## Integers
All integer keys are stored as big-endian.
If the integer is signed, then add half of the unsigned range: 2^(n-1) (8-bit => 128).
This is equivalent to flipping the sign bit.
* 255 unsigned 4-byte => `00 00 00 FF`
* -32768 signed 2-byte => `00 00`
* -1 signed 2-byte => `7F FF`
* 0 signed 2-byte => `80 00`
* 32767 signed 2-byte => `FF FF`
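A minimal sketch of this transformation for 2-byte signed keys (a hypothetical helper, not the crate's API); the expected outputs match the examples above:

```rust
// Bias a signed 16-bit value by 2^15 (equivalently: flip the sign bit), then
// store it big-endian so that byte-wise ordering equals numeric ordering.
fn encode_i16_key(value: i16) -> [u8; 2] {
    let biased = (value as u16) ^ 0x8000;
    biased.to_be_bytes()
}

fn main() {
    assert_eq!(encode_i16_key(-32768), [0x00, 0x00]);
    assert_eq!(encode_i16_key(-1), [0x7F, 0xFF]);
    assert_eq!(encode_i16_key(0), [0x80, 0x00]);
    assert_eq!(encode_i16_key(32767), [0xFF, 0xFF]);
}
```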
## Strings
All string keys are stored as UTF-8 and are null-terminated.
A length is not prefixed because this would effectively make the strings sorted
by length instead of lexicographically.
UTF-8 has the property of lexicographic sorting. Even with extension bytes,
the string will sort in ascending order of the code points.
The null terminator is needed to indicate the end of the string.
It also serves as a separator from other multi-column values in the key.
Longer strings that share the same beginning as another string are sorted after.
```
41 70 70 6C 65 00 // Apple
41 70 70 6C 65 73 00 // Apples
41 CC 88 70 66 65 6C 00 // Äpfel (NFD)
42 61 6E 61 6E 61 00 // Banana
42 61 6E 61 6E 61 73 00 // Bananas
42 61 6E 64 00 // Band
42 65 65 68 69 76 65 00 // Beehive
42 65 65 73 00 // Bees
61 70 70 6C 65 00 // apple
C3 84 70 66 65 6C 00 // Äpfel (NFC)
```
* `WHERE x LIKE 'Apple%'` => `41 70 70 6C 65`
* `WHERE x = 'Apple'` => `41 70 70 6C 65 00`
Strings are sorted by their UTF-8 representation, and not with a collation
algorithm.
It's theoretically possible to index strings using a collation algorithm if
the algorithm can return a byte representation that sorts the same way.
However, this is not yet supported.
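As a sketch (hypothetical helper, not part of the crate), building such a key is just the UTF-8 bytes plus a terminating zero byte:

```rust
fn encode_string_key(s: &str) -> Vec<u8> {
    // UTF-8 bytes preserve code-point order; the trailing 0x00 marks the end
    // of the string and separates it from any following key columns.
    let mut key = Vec::with_capacity(s.len() + 1);
    key.extend_from_slice(s.as_bytes());
    key.push(0x00);
    key
}

fn main() {
    assert_eq!(encode_string_key("Apple"), b"Apple\0".to_vec());
    // Byte-wise ordering matches the list above: "Apple" < "Apples" < "Banana".
    let mut keys = vec![
        encode_string_key("Banana"),
        encode_string_key("Apples"),
        encode_string_key("Apple"),
    ];
    keys.sort();
    assert_eq!(keys[0], encode_string_key("Apple"));
    assert_eq!(keys[2], encode_string_key("Banana"));
}
```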
## Floating point numbers
This encoding is mostly compatible with the number ranges from IEEE 754.
The only exception is NaN, which this encoding does not support.
NaN is unsortable/incomparable, and therefore cannot be encoded.
This encoding is basically the same as binary32 IEEE 754, but with flipped bits.
Like the integer types, the encoding is in big-endian
(the byte with the sign bit comes first).
To convert IEEE 754 to or from this encoding:
* If the number is negative, flip all the bits.
* If the number is positive, flip the sign bit.
This way, an encoding of `00 7F FF FF` corresponds to negative infinity (sign bit set,
highest exponent), which is the smallest possible floating point value.
Similarly, an encoding of `FF 80 00 00` corresponds to positive infinity (highest exponent),
which is the largest possible floating point value.
* -inf => `00 7F FF FF`
* -1 => `40 7F FF FF`
* -0 => `7F FF FF FF`
* +0 => `80 00 00 00`
* +1 => `BF 80 00 00`
* +inf => `FF 80 00 00`
The removal of NaN disqualifies 16,777,214 values.
Ranges that the removal of NaN disqualifies (inclusive):
* `00 00 00 00` to `00 7F FF FE`
* `FF 80 00 01` to `FF FF FF FF`
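A small sketch of this transformation for `f32` (hypothetical helper, not the crate's API); the expected values match the list above:

```rust
// Map an f32 to 4 big-endian bytes that sort byte-wise in numeric order.
// Negative numbers get every bit flipped; non-negative numbers only get the
// sign bit flipped. NaN is rejected because it has no defined ordering.
fn encode_f32_key(value: f32) -> Option<[u8; 4]> {
    if value.is_nan() {
        return None;
    }
    let bits = value.to_bits();
    let encoded = if bits & 0x8000_0000 != 0 {
        !bits // negative: flip all bits
    } else {
        bits | 0x8000_0000 // positive (or +0): flip the sign bit
    };
    Some(encoded.to_be_bytes())
}

fn main() {
    assert_eq!(encode_f32_key(f32::NEG_INFINITY), Some([0x00, 0x7F, 0xFF, 0xFF]));
    assert_eq!(encode_f32_key(-1.0), Some([0x40, 0x7F, 0xFF, 0xFF]));
    assert_eq!(encode_f32_key(1.0), Some([0xBF, 0x80, 0x00, 0x00]));
    assert_eq!(encode_f32_key(f32::INFINITY), Some([0xFF, 0x80, 0x00, 0x00]));
    assert_eq!(encode_f32_key(f32::NAN), None);
}
```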

27
docs/Pager.md Normal file
View File

@ -0,0 +1,27 @@
# Pager
The pager module partitions a backing store into a cache-friendly, addressable,
and fixed-sized collection of pages.
## Pager implementations
There are two pager implementations: **Disk** and **Memory**.
* A disk pager ideally has a page size that matches the device's sector size.
* This is usually 512 or 4096 bytes.
* A memory pager ideally has a page size that matches the operating system's virtual memory page size.
* On most architectures, this is 4096 bytes.
For the most part, the disk and memory pagers have a lot in common.
Both pagers' backing stores are segmented at the hardware level, and if
exploited can yield faster data access through caching.
This means the same pager abstractions can be used for both disk and memory
without the abstractions being too leaky.
## Invariants
* Page ID can be any unique value except for 0.
* The minimum page size: **64 bytes**.
* The maximum page size: **65536 bytes**.
* The page size must be a power of 2.
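A minimal sketch (assumed helper names, not the crate's API) of checking these invariants:

```rust
// The page size must be a power of two between 64 and 65536 bytes.
fn valid_page_size(page_size: usize) -> bool {
    page_size.is_power_of_two() && page_size >= 64 && page_size <= 65536
}

// Page id 0 is reserved and never refers to a real page.
fn valid_page_id(page_id: u64) -> bool {
    page_id != 0
}

fn main() {
    assert!(valid_page_size(4096));
    assert!(!valid_page_size(1000)); // not a power of two
    assert!(!valid_page_size(32));   // below the 64-byte minimum
    assert!(!valid_page_id(0));
}
```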

11
src/btree/bound.rs Normal file
View File

@ -0,0 +1,11 @@
pub enum Bound<'a> {
Included(&'a [u8]),
Excluded(&'a [u8]),
Unbounded
}
#[derive(PartialEq)]
pub enum Order {
Ascending,
Descending
}

105
src/btree/cell.rs Normal file
View File

@ -0,0 +1,105 @@
use byteutils;
#[derive(Debug, PartialEq)]
pub enum BTreeCellErr {
CellLengthTooSmall(usize),
PayloadLengthTooSmall(u32)
}
#[derive(Debug, PartialEq)]
pub struct BTreeCell<'a> {
pub left_page: u64,
pub payload_length: u32,
pub overflow_page: Option<u64>,
pub in_page_payload: &'a [u8]
}
impl<'a> BTreeCell<'a> {
/// Returns None if the data is corrupt
pub fn read(data: &'a [u8]) -> Result<BTreeCell<'a>, BTreeCellErr> {
if data.len() < 24 {
return Err(BTreeCellErr::CellLengthTooSmall(data.len()));
}
let left_page = byteutils::read_u64_le(&data[0..8]);
let payload_length = byteutils::read_u32_le(&data[8..12]);
if payload_length < 4 {
return Err(BTreeCellErr::PayloadLengthTooSmall(payload_length));
}
if payload_length as usize > data.len() - 12 {
let overflow_page = byteutils::read_u64_le(&data[12..20]);
Ok(BTreeCell {
left_page: left_page,
payload_length: payload_length,
overflow_page: Some(overflow_page),
in_page_payload: &data[20..]
})
} else {
Ok(BTreeCell {
left_page: left_page,
payload_length: payload_length,
overflow_page: None,
in_page_payload: &data[12..12+payload_length as usize]
})
}
}
}
#[cfg(test)]
mod test {
use super::BTreeCell;
#[test]
fn test_btree_cell_unused() {
// Cell with unused data
assert_eq!(BTreeCell::read(&[
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
5, 0, 0, 0,
9, 8, 7, 6, 5,
0, 0, 0, 0, 0, 0, 0 // unused data (and padding to 24 bytes)
]).unwrap(), BTreeCell {
left_page: 0x0807060504030201,
payload_length: 5,
overflow_page: None,
in_page_payload: &[9, 8, 7, 6, 5]
});
}
#[test]
fn test_btree_cell_overflow() {
// Cell with overflow
assert_eq!(BTreeCell::read(&[
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
27, 0, 0, 0,
0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10,
9, 8, 7, 6, 5, 4, 3, 2, 1, 0
]).unwrap(), BTreeCell {
left_page: 0x0807060504030201,
payload_length: 27,
overflow_page: Some(0x100F0E0D0C0B0A09),
in_page_payload: &[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
});
}
#[test]
fn test_btree_cell_corrupt() {
use super::BTreeCellErr::*;
// Cell length is too small
assert_eq!(BTreeCell::read(&[
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
5, 0, 0, 0,
9, 8, 7, 6, 5
]).unwrap_err(), CellLengthTooSmall(17));
// Payload length is too small
assert_eq!(BTreeCell::read(&[
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
3, 0, 0, 0,
9, 8, 7,
0, 0, 0, 0, 0, 0, 0, 0, 0 // unused data (and padding to 24 bytes)
]).unwrap_err(), PayloadLengthTooSmall(3));
}
}

151
src/btree/mod.rs Normal file
View File

@ -0,0 +1,151 @@
mod bound;
mod cell;
mod page;
pub use self::bound::*;
use super::pager::Pager;
mod btree_consts {
pub const MAX_DEPTH: usize = 256;
}
#[derive(Debug)]
pub enum BTreeError<P: Pager> {
PagerError(P::Err),
BTreePageError(self::page::BTreePageErr)
}
fn pager_error<P: Pager>(err: P::Err) -> BTreeError<P> {
BTreeError::PagerError(err)
}
pub struct BTreeCollection<'a, P: Pager + 'a> {
pager: &'a P
}
impl<'a, P: Pager + 'a> BTreeCollection<'a, P> {
pub fn new<'l>(pager: &'l P) -> BTreeCollection<'l, P> {
BTreeCollection {
pager: pager
}
}
pub fn get_btree(&self, page_id: u64) -> Result<BTree<P>, BTreeError<P>>
{
try!(self.pager.check_page_id(page_id).map_err(pager_error::<P>));
Ok(BTree {
pager: self.pager,
root_page_id: page_id
})
}
}
pub struct BTreeCollectionMut<'a, P: Pager + 'a> {
pager: &'a mut P
}
impl<'a, P: Pager + 'a> BTreeCollectionMut<'a, P> {
pub fn new<'l>(pager: &'l mut P) -> BTreeCollectionMut<'l, P> {
BTreeCollectionMut {
pager: pager
}
}
pub fn new_btree(&mut self, cell_length: u16) -> Result<u64, BTreeError<P>>
{
let page_id = try!(self.pager.new_page(|buffer| {
let write_result = page::BTreePageWrite {
root: true,
leaf: true,
page_cell_count: 0,
cell_length: cell_length,
right_page: None
}.write(&mut buffer[0..16]);
}).map_err(pager_error::<P>));
Ok(page_id)
}
pub fn get_btree(&mut self, page_id: u64) -> Result<BTree<P>, BTreeError<P>>
{
try!(self.pager.check_page_id(page_id).map_err(pager_error::<P>));
Ok(BTree {
pager: self.pager,
root_page_id: page_id
})
}
pub fn mut_btree(&mut self, page_id: u64) -> Result<BTreeMut<P>, BTreeError<P>>
{
try!(self.pager.check_page_id(page_id).map_err(pager_error::<P>));
Ok(BTreeMut {
pager: self.pager,
root_page_id: page_id
})
}
pub fn remove_btree(&mut self, page_id: u64) -> Result<(), BTreeError<P>>
{
try!(self.pager.check_page_id(page_id).map_err(pager_error::<P>));
try!(self.pager.mark_page_as_removed(page_id).map_err(pager_error::<P>));
unimplemented!()
}
}
/// The B+Tree is assumed to be ordered byte-wise.
pub struct BTree<'a, P: Pager + 'a> {
pager: &'a P,
root_page_id: u64
}
impl<'a, P: Pager + 'a> BTree<'a, P> {
pub fn find_keys(&self, min: Bound, max: Bound, order: Order) -> Result<BTreeKeyIter, BTreeError<P>>
{
// Assume Ascending order
if order == Order::Descending { unimplemented!() }
unimplemented!()
}
}
pub struct BTreeMut<'a, P: Pager + 'a> {
pager: &'a mut P,
root_page_id: u64
}
impl<'a, P: Pager + 'a> BTreeMut<'a, P> {
pub fn insert_key(&mut self, key: &[u8]) -> Result<(), BTreeError<P>>
{
// Non-root page:
// If the page overflows, split it so that the first half remains in the
// current page, and the second half is put in a new page.
// Insert the middle key into the parent page.
//
// Root page:
// If the page overflows, split it into two new pages.
// Clear the current page and insert the middle key into the current page.
//
// On split, if the cell count is odd,
unimplemented!()
}
pub fn update_keys<F>(&mut self, min: Bound, max: Bound, transform: F) -> Result<(), BTreeError<P>>
where F: FnMut(&[u8], &mut Vec<u8>)
{
unimplemented!()
}
pub fn remove_keys(&mut self, min: Bound, max: Bound) -> Result<(), BTreeError<P>>
{
unimplemented!()
}
}
pub struct BTreeKeyIter;

154
src/btree/page.rs Normal file
View File

@ -0,0 +1,154 @@
use byteutils;
#[derive(Debug, PartialEq)]
pub enum BTreePageErr {
PageLengthWrong(usize),
HeaderLengthWrong(usize),
InteriorMustContainRightPage
}
mod consts {
pub const MIN_PAGE_LENGTH: usize = 1 << 6;
pub const ROOT_PAGE: u8 = 0b0000_0010;
pub const LEAF_PAGE: u8 = 0b0000_0001;
}
#[derive(Debug, PartialEq)]
pub struct BTreePageRead<'a> {
pub root: bool,
pub leaf: bool,
pub page_cell_count: u16,
pub cell_length: u16,
pub right_page: Option<u64>,
pub data: &'a [u8]
}
impl<'a> BTreePageRead<'a> {
pub fn read(data: &'a [u8]) -> Result<BTreePageRead<'a>, BTreePageErr> {
use std::num::Int;
// Ensure the page length is a power of two and is the minimum page length
if !(data.len().count_ones() == 1 && data.len() >= consts::MIN_PAGE_LENGTH) {
return Err(BTreePageErr::PageLengthWrong(data.len()));
}
let flags = data[0];
let page_cell_count = byteutils::read_u16_le(&data[4..6]);
let cell_length = byteutils::read_u16_le(&data[6..8]);
let root = flags & consts::ROOT_PAGE != 0;
let leaf = flags & consts::LEAF_PAGE != 0;
// TODO: check for more invariants, such as page_cell_count and cell_length
let right_page = match byteutils::read_u64_le(&data[8..16]) {
0 => {
// Make sure that this is a leaf node.
// "0" indicates the last leaf node of the B+Tree.
if leaf {
None
} else {
return Err(BTreePageErr::InteriorMustContainRightPage);
}
},
right_page => Some(right_page)
};
Ok(BTreePageRead {
root: root,
leaf: leaf,
page_cell_count: page_cell_count,
cell_length: cell_length,
right_page: right_page,
data: &data[16..]
})
}
pub fn to_write(&self) -> BTreePageWrite {
BTreePageWrite {
root: self.root,
leaf: self.leaf,
page_cell_count: self.page_cell_count,
cell_length: self.cell_length,
right_page: self.right_page,
}
}
}
pub struct BTreePageWrite {
pub root: bool,
pub leaf: bool,
pub page_cell_count: u16,
pub cell_length: u16,
pub right_page: Option<u64>
}
impl BTreePageWrite {
pub fn write(&self, data: &mut [u8]) -> Result<(), BTreePageErr> {
if data.len() != 16 {
return Err(BTreePageErr::HeaderLengthWrong(data.len()));
}
if !self.leaf && self.right_page.is_none() {
return Err(BTreePageErr::InteriorMustContainRightPage);
}
// TODO: check for more invariants, such as page_cell_count and cell_length
let right_page = match self.right_page {
None => 0,
Some(page) => page
};
let flags = {
let mut f = 0;
if self.root { f |= consts::ROOT_PAGE }
if self.leaf { f |= consts::LEAF_PAGE }
f
};
data[0] = flags;
data[1] = 0;
data[2] = 0;
data[3] = 0;
byteutils::write_u16_le(self.page_cell_count, &mut data[4..6]);
byteutils::write_u16_le(self.cell_length, &mut data[6..8]);
byteutils::write_u64_le(right_page, &mut data[8..16]);
Ok(())
}
}
#[cfg(test)]
mod test {
use super::{BTreePageRead, BTreePageWrite};
#[test]
fn test_btree_page_readwrite() {
let header_buf = [
0x02,
0, 0, 0,
5, 0,
24, 0,
0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C, 0x0E, 0x10
];
let mut page_buf: Vec<u8> = header_buf.to_vec();
page_buf.extend(0..128-16);
let page = BTreePageRead::read(page_buf.as_slice()).unwrap();
assert_eq!(page, BTreePageRead {
root: true,
leaf: false,
page_cell_count: 5,
cell_length: 24,
right_page: Some(0x100E0C0A08060402),
data: &page_buf.as_slice()[16..]
});
let mut write_header_buf = [0; 16];
page.to_write().write(&mut write_header_buf).unwrap();
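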
assert_eq!(header_buf, write_header_buf);
}
}

192
src/byteutils.rs Normal file
View File

@ -0,0 +1,192 @@
pub fn read_u16_le(buf: &[u8]) -> u16 {
assert_eq!(buf.len(), 2);
buf.iter().enumerate().fold(0, |prev, (i, v)| {
prev | ((*v as u16) << (i*8))
})
}
pub fn read_u32_le(buf: &[u8]) -> u32 {
assert_eq!(buf.len(), 4);
buf.iter().enumerate().fold(0, |prev, (i, v)| {
prev | ((*v as u32) << (i*8))
})
}
pub fn read_u64_le(buf: &[u8]) -> u64 {
assert_eq!(buf.len(), 8);
buf.iter().enumerate().fold(0, |prev, (i, v)| {
prev | ((*v as u64) << (i*8))
})
}
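/// Reads a variable-length unsigned integer: big-endian groups of 7 bits,
/// where a set high bit means "more bytes follow" and a clear high bit ends the number.
/// Returns (bytes consumed, value), or None if the buffer ended mid-number.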
#[must_use = "must use returned length"]
pub fn read_uvar(buf: &[u8]) -> Option<(usize, u64)> {
let mut accum = 0;
for (i, v) in buf.iter().enumerate() {
let has_more = (v & 0x80) != 0;
accum = (accum << 7) | (*v as u64 & 0x7F);
if !has_more {
return Some((i+1, accum));
}
}
None
}
pub fn write_u16_le(value: u16, buf: &mut [u8]) {
assert_eq!(buf.len(), 2);
for (i, v) in buf.iter_mut().enumerate() {
let byte = ((value & (0xFF << (i*8))) >> (i*8)) as u8;
*v = byte;
}
}
pub fn write_u32_le(value: u32, buf: &mut [u8]) {
assert_eq!(buf.len(), 4);
for (i, v) in buf.iter_mut().enumerate() {
let byte = ((value & (0xFF << (i*8))) >> (i*8)) as u8;
*v = byte;
}
}
pub fn write_u64_le(value: u64, buf: &mut [u8]) {
assert_eq!(buf.len(), 8);
for (i, v) in buf.iter_mut().enumerate() {
let byte = ((value & (0xFF << (i*8))) >> (i*8)) as u8;
*v = byte;
}
}
/// Maximum buffer size needed for 64-bit number: 10 bytes
#[must_use = "must use returned length"]
pub fn write_uvar(value: u64, buf: &mut [u8]) -> Option<usize> {
let mut remainder = value;
for i in 0..buf.len() {
let data = (remainder & 0x7F) as u8;
remainder = remainder >> 7;
let has_more = remainder != 0;
buf[i] = if i == 0 {
data
} else {
0x80 | data
};
if !has_more {
// Reverse the buffer; most significant numbers should go first.
buf[0..i+1].reverse();
return Some(i + 1)
}
}
// The buffer wasn't long enough
None
}
#[cfg(test)]
mod test {
use super::{read_u16_le, read_u32_le, read_u64_le, read_uvar};
use super::{write_u16_le, write_u32_le, write_u64_le, write_uvar};
static TEST_U16: [(u16, &'static [u8]); 3] = [
(0x0201, &[0x01, 0x02]),
(0x0000, &[0x00, 0x00]),
(0xFFFF, &[0xFF, 0xFF]),
];
static TEST_U32: [(u32, &'static [u8]); 1] = [
(0x04030201, &[0x01, 0x02, 0x03, 0x04])
];
static TEST_U64: [(u64, &'static [u8]); 1] = [
(0x0807060504030201, &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08])
];
static TEST_UVAR: [(u64, &'static [u8]); 8] = [
(0x00, &[0x00]),
(0x7F, &[0x7F]),
(0x80, &[0x81, 0x00]),
(0xFF, &[0x81, 0x7F]),
(0x0100, &[0x82, 0x00]),
(0xFFFF_FFFF, &[0x8F, 0xFF, 0xFF, 0xFF, 0x7F]),
(0x1234_5678_9ABC_DEF0, &[0x92, 0x9A, 0x95, 0xCF, 0x89, 0xD5, 0xF3, 0xBD, 0x70]),
(0xFFFF_FFFF_FFFF_FFFF, &[0x81, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
];
#[test]
fn test_read_u16_le() {
for &(v, buf) in TEST_U16.iter() {
assert_eq!(v, read_u16_le(buf));
}
}
#[test]
fn test_read_u32_le() {
for &(v, buf) in TEST_U32.iter() {
assert_eq!(v, read_u32_le(buf));
}
}
#[test]
fn test_read_u64_le() {
for &(v, buf) in TEST_U64.iter() {
assert_eq!(v, read_u64_le(buf));
}
}
#[test]
fn test_read_uvar() {
for &(v, buf) in TEST_UVAR.iter() {
assert_eq!((buf.len(), v), read_uvar(buf).unwrap());
}
}
#[test]
fn test_write_u16_le() {
let mut write_buf = [0; 2];
for &(v, buf) in TEST_U16.iter() {
write_u16_le(v, &mut write_buf);
assert_eq!(buf, write_buf);
}
}
#[test]
fn test_write_u32_le() {
let mut write_buf = [0; 4];
for &(v, buf) in TEST_U32.iter() {
write_u32_le(v, &mut write_buf);
assert_eq!(buf, write_buf);
}
}
#[test]
fn test_write_u64_le() {
let mut write_buf = [0; 8];
for &(v, buf) in TEST_U64.iter() {
write_u64_le(v, &mut write_buf);
assert_eq!(buf, write_buf);
}
}
#[test]
fn test_write_uvar() {
let mut write_buf = [0; 10];
for &(v, buf) in TEST_UVAR.iter() {
let written = write_uvar(v, &mut write_buf).unwrap();
assert_eq!(buf, &write_buf[0..written]);
}
}
}

18
src/lib.rs Normal file
View File

@ -0,0 +1,18 @@
// #![feature(core, old_io, old_path, os, collections, unicode)]
// #![allow(unused_variables, dead_code)]
pub mod btree;
pub mod pager;
pub mod pagermemory;
pub mod pagerstream;
pub mod sqlsyntax;
mod byteutils;
pub use self::pager::Pager;
pub use self::pagermemory::PagerMemory;
pub use self::pagerstream::PagerStream;
pub enum SQLError {
}
pub type SQLResult<T> = Result<T, SQLError>;

52
src/pager.rs Normal file
View File

@ -0,0 +1,52 @@
use std::fmt::Debug;
pub const MIN_PAGE_SIZE: usize = 1 << 6;
pub enum PageReference<'a, 'b> {
Immutable(&'a [u8]),
Mutable(&'b mut [u8])
}
pub trait Pager {
type Err: Debug;
fn max_page_size(&self) -> usize;
/// Checks if the page id exists inside the Pager.
///
/// If the page id doesn't exist, an implementation-defined error is returned.
fn check_page_id(&self, page_id: u64) -> Result<(), Self::Err>;
/// Allocates a new page. All bytes in the page are uninitialized.
///
/// The buffer may contain remnants of previous pager operations, so
/// reading said data may make potential bugs in the database more
/// unpredictable and harder to identify.
///
/// Returns a unique page id. The page id can be any value except for 0.
unsafe fn new_page_uninitialized<F>(&mut self, f: F) -> Result<u64, Self::Err>
where F: FnOnce(&mut [u8]);
/// Allocates a new page. All bytes in the page are initialized to zeros.
///
/// Returns a unique page id. The page id can be any value except for 0.
fn new_page<F>(&mut self, f: F) -> Result<u64, Self::Err>
where F: FnOnce(&mut [u8])
{
unsafe {
self.new_page_uninitialized(|buffer| {
for x in buffer.iter_mut() { *x = 0; }
f(buffer);
})
}
}
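/// Calls `f` with a mutable view of the page's contents and returns the closure's result.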
fn mut_page<F, R>(&mut self, page_id: u64, f: F) -> Result<R, Self::Err>
where F: FnOnce(&mut [u8]) -> R;
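/// Reads a page, returning either a reference into the pager's own storage or into the caller-supplied buffer.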
fn get_page<'a, 'b>(&'a self, page_id: u64, buffer: &'b mut Vec<u8>) -> Result<PageReference<'a, 'b>, Self::Err>;
fn increment_change_counter(&mut self) -> Result<(), Self::Err>;
fn mark_page_as_removed(&mut self, page_id: u64) -> Result<(), Self::Err>;
}

135
src/pagermemory.rs Normal file
View File

@ -0,0 +1,135 @@
use super::pager;
use super::pager::{PageReference, Pager};
use std::os;
use std::collections::{HashMap, VecDeque};
#[derive(Debug)]
pub enum PagerMemoryErr {
MapError(os::MapError),
PageDoesNotExist(u64)
}
pub type PagerMemoryResult<T> = Result<T, PagerMemoryErr>;
fn map_error(err: os::MapError) -> PagerMemoryErr {
PagerMemoryErr::MapError(err)
}
/// The backing store for the memory pager is `mmap()`, or the operating
/// system's equivalent to `mmap()`.
/// The page size is queried from the operating system to improve memory
/// caching.
pub struct PagerMemory {
max_page_size: usize,
pages: HashMap<u64, os::MemoryMap>,
next_page_id: u64,
unused_pages: VecDeque<os::MemoryMap>
}
impl PagerMemory {
pub fn new() -> PagerMemory {
PagerMemory {
// TODO: query maximum memory map length, somehow.
// For now, it's safe to assume it's probably 4096
max_page_size: 4096,
pages: HashMap::new(),
next_page_id: 0,
unused_pages: VecDeque::new()
}
}
}
impl Pager for PagerMemory {
type Err = PagerMemoryErr;
fn max_page_size(&self) -> usize { self.max_page_size }
fn check_page_id(&self, page_id: u64) -> PagerMemoryResult<()> {
if self.pages.contains_key(&page_id) {
Ok(())
} else {
Err(PagerMemoryErr::PageDoesNotExist(page_id))
}
}
unsafe fn new_page_uninitialized<F>(&mut self, f: F) -> PagerMemoryResult<u64>
where F: FnOnce(&mut [u8])
{
let mut memory_map = match self.unused_pages.pop_back() {
Some(v) => v,
None => {
use std::os::MapOption::*;
let m = try!(os::MemoryMap::new(pager::MIN_PAGE_SIZE, &[MapReadable, MapWritable]).map_err(map_error));
assert!(m.len() <= self.max_page_size);
m
}
};
f(memory_map_as_mut_slice(&mut memory_map));
let page_id = self.next_page_id;
match self.pages.insert(page_id, memory_map) {
None => (),
Some(_) => unreachable!()
}
self.next_page_id += 1;
Ok(page_id)
}
fn mut_page<F, R>(&mut self, page_id: u64, f: F) -> PagerMemoryResult<R>
where F: FnOnce(&mut [u8]) -> R
{
let memory_map = match self.pages.get_mut(&page_id) {
Some(m) => m,
None => return Err(PagerMemoryErr::PageDoesNotExist(page_id))
};
let result = f(memory_map_as_mut_slice(memory_map));
Ok(result)
}
fn get_page<'a, 'b>(&'a self, page_id: u64, buffer: &'b mut Vec<u8>) -> PagerMemoryResult<PageReference<'a, 'b>>
{
let memory_map = match self.pages.get(&page_id) {
Some(m) => m,
None => return Err(PagerMemoryErr::PageDoesNotExist(page_id))
};
let src = memory_map_as_slice(memory_map);
Ok(PageReference::Immutable(src))
}
fn increment_change_counter(&mut self) -> PagerMemoryResult<()> {
// do nothing
Ok(())
}
fn mark_page_as_removed(&mut self, page_id: u64) -> PagerMemoryResult<()> {
match self.pages.remove(&page_id) {
Some(memory_map) => {
self.unused_pages.push_back(memory_map);
Ok(())
},
None => Err(PagerMemoryErr::PageDoesNotExist(page_id))
}
}
}
fn memory_map_as_slice<'a>(memory_map: &'a os::MemoryMap) -> &'a [u8] {
use std::slice;
unsafe {
slice::from_raw_parts(memory_map.data(), memory_map.len())
}
}
fn memory_map_as_mut_slice<'a>(memory_map: &'a mut os::MemoryMap) -> &'a mut [u8] {
use std::slice;
unsafe {
slice::from_raw_parts_mut(memory_map.data(), memory_map.len())
}
}

282
src/pagerstream.rs Normal file
View File

@ -0,0 +1,282 @@
use std::old_io::fs::File;
use std::old_io as io;
use std::old_path::Path;
use std::cell::UnsafeCell;
use super::pager::{PageReference, Pager};
use super::byteutils;
/// 1-byte page-size exponent, 8-byte change counter, 8-byte freelist head.
const HEADER_SIZE: usize = 17;
/// Min page size: 2^6 = 64
const MIN_PAGE_SIZE_VALUE: u8 = 6;
/// Max page size: 2^16 = 65536
const MAX_PAGE_SIZE_VALUE: u8 = 16;
#[derive(Debug)]
pub enum PagerStreamErr {
IoError(io::IoError),
TriedToCreateOnNonEmptyStream,
BadPageSize(usize),
BadHeader,
BadStreamSize(u64),
PageDoesNotExist(u64)
}
pub type PagerStreamResult<T> = Result<T, PagerStreamErr>;
fn io_error(err: io::IoError) -> PagerStreamErr {
PagerStreamErr::IoError(err)
}
pub struct PagerStream<S>
where S: io::Reader + io::Writer + io::Seek
{
/// The stream is contained in an UnsafeCell to allow I/O access inside &self methods.
stream: UnsafeCell<S>,
page_size: usize,
page_count: u64,
/// A pre-allocated buffer for `mut_page()`. Keeping it here will avoid expensive reallocations.
/// The buffer is contained in an UnsafeCell to allow usage with the stream, which needs to be simultaneously borrowed.
mut_page_buffer: UnsafeCell<Box<[u8]>>
}
impl<S> PagerStream<S>
where S: io::Reader + io::Writer + io::Seek
{
pub fn open(mut stream: S) -> PagerStreamResult<PagerStream<S>> {
// Get stream size
try!(stream.seek(0, io::SeekStyle::SeekEnd).map_err(io_error));
let stream_size = try!(stream.tell().map_err(io_error));
// Seek to the beginning
try!(stream.seek(0, io::SeekStyle::SeekSet).map_err(io_error));
// Read header
let mut header_bytes = [0; HEADER_SIZE];
try!(stream.read_at_least(HEADER_SIZE, &mut header_bytes).map_err(io_error));
let header = try!(Header::parse(&header_bytes));
let page_count = match (stream_size / header.page_size as u64, stream_size % header.page_size as u64) {
(page_count, 0) => page_count,
(0, _) | (_, _) => return Err(PagerStreamErr::BadStreamSize(stream_size)),
};
Ok(PagerStream {
stream: UnsafeCell::new(stream),
page_size: header.page_size,
page_count: page_count,
mut_page_buffer: UnsafeCell::new(vec![0; header.page_size].into_boxed_slice())
})
}
pub fn create(mut stream: S, page_size: usize) -> PagerStreamResult<PagerStream<S>> {
// Get stream size
try!(stream.seek(0, io::SeekStyle::SeekEnd).map_err(io_error));
let stream_size = try!(stream.tell().map_err(io_error));
if stream_size != 0 {
return Err(PagerStreamErr::TriedToCreateOnNonEmptyStream);
}
let header = Header {
page_size: page_size,
change_counter: 0,
freelist_head: 0
};
let mut header_bytes = [0; HEADER_SIZE];
try!(header.serialize(&mut header_bytes));
// Write the header
try!(stream.write_all(&header_bytes).map_err(io_error));
// Pad the rest of the page with zeros
let padding = vec![0; page_size - header_bytes.len()];
try!(stream.write_all(padding.as_slice()).map_err(io_error));
Ok(PagerStream {
stream: UnsafeCell::new(stream),
page_size: page_size,
page_count: 1,
mut_page_buffer: UnsafeCell::new(vec![0; page_size].into_boxed_slice())
})
}
fn stream(&self) -> &mut S {
use std::mem;
unsafe { mem::transmute(self.stream.get()) }
}
fn mut_page_buffer(&self) -> &mut [u8] {
use std::mem;
let buffer_box: &mut Box<[u8]> = unsafe { mem::transmute(self.mut_page_buffer.get()) };
&mut *buffer_box
}
}
pub fn open_from_path(path: &Path) -> PagerStreamResult<PagerStream<File>> {
use std::old_io::{FileAccess, FileMode};
let file = try!(File::open_mode(path, FileMode::Append, FileAccess::ReadWrite).map_err(io_error));
PagerStream::open(file)
}
pub fn create_from_path(path: &Path, page_size: usize) -> PagerStreamResult<PagerStream<File>> {
use std::old_io::{FileAccess, FileMode};
let file = try!(File::open_mode(path, FileMode::Truncate, FileAccess::ReadWrite).map_err(io_error));
PagerStream::create(file, page_size)
}
struct Header {
page_size: usize,
change_counter: u64,
freelist_head: u64
}
impl Header {
fn parse(header: &[u8; HEADER_SIZE]) -> PagerStreamResult<Header> {
use std::cmp::Ord;
fn check_range<T>(value: T, min: T, max: T) -> PagerStreamResult<T>
where T: Ord
{
if (value <= max) && (value >= min) { Ok(value) }
else { Err(PagerStreamErr::BadHeader) }
}
let page_size: usize = 1 << try!(check_range(header[0], MIN_PAGE_SIZE_VALUE, MAX_PAGE_SIZE_VALUE));
let change_counter: u64 = byteutils::read_u64_le(&header[1..9]);
let freelist_head: u64 = byteutils::read_u64_le(&header[9..17]);
Ok(Header {
page_size: page_size,
change_counter: change_counter,
freelist_head: freelist_head
})
}
fn serialize(&self, buffer: &mut [u8; HEADER_SIZE]) -> PagerStreamResult<()> {
use std::num::Int;
for x in buffer.iter_mut() { *x = 0; }
if self.page_size.count_ones() != 1 {
return Err(PagerStreamErr::BadPageSize(self.page_size));
}
// Store the power-of-two exponent; parse() recovers the size with `1 << header[0]`.
let page_size_shl = self.page_size.trailing_zeros();
buffer[0] = page_size_shl as u8;
byteutils::write_u64_le(self.change_counter, &mut buffer[1..9]);
byteutils::write_u64_le(self.freelist_head, &mut buffer[9..17]);
Ok(())
}
}
impl<S> Pager for PagerStream<S>
where S: io::Reader + io::Writer + io::Seek
{
type Err = PagerStreamErr;
fn max_page_size(&self) -> usize { self.page_size }
fn check_page_id(&self, page_id: u64) -> PagerStreamResult<()> {
if page_id == 0 || page_id >= self.page_count {
Err(PagerStreamErr::PageDoesNotExist(page_id))
} else {
Ok(())
}
}
unsafe fn new_page_uninitialized<F>(&mut self, f: F) -> PagerStreamResult<u64>
where F: FnOnce(&mut [u8])
{
let page_id = self.page_count;
{
let stream = self.stream();
let buffer = self.mut_page_buffer();
try!(stream.seek(0, io::SeekStyle::SeekEnd).map_err(io_error));
f(buffer);
// Write the new page
try!(stream.write_all(buffer.as_slice()).map_err(io_error));
}
self.page_count += 1;
Ok(page_id)
}
fn mut_page<F, R>(&mut self, page_id: u64, f: F) -> PagerStreamResult<R>
where F: FnOnce(&mut [u8]) -> R
{
try!(self.check_page_id(page_id));
let stream = self.stream();
let buffer = self.mut_page_buffer();
// Seek to the requested page
let page_offset: u64 = page_id * self.page_size as u64;
try!(stream.seek(page_offset as i64, io::SeekStyle::SeekSet).map_err(io_error));
try!(stream.read_at_least(self.page_size, buffer).map_err(io_error));
// Mutate the page buffer
let result = f(buffer);
// Write the mutated page back
try!(stream.seek(page_offset as i64, io::SeekStyle::SeekSet).map_err(io_error));
try!(stream.write_all(buffer).map_err(io_error));
Ok(result)
}
fn get_page<'a, 'b>(&'a self, page_id: u64, buffer: &'b mut Vec<u8>) -> PagerStreamResult<PageReference<'a, 'b>>
{
use std::slice;
try!(self.check_page_id(page_id));
let stream = self.stream();
// Seek to the requested page
let page_offset: u64 = page_id * self.page_size as u64;
try!(stream.seek(page_offset as i64, io::SeekStyle::SeekSet).map_err(io_error));
// Ensure the buffer has enough capacity to store `self.page_size` contiguous bytes.
if self.page_size > buffer.capacity() {
let reserve = self.page_size - buffer.capacity();
buffer.reserve(reserve);
}
unsafe {
// Set the buffer length to 0, in case the I/O operations fail.
// If I/O fails, the buffer will appear empty to the caller.
buffer.set_len(0);
let buffer_ptr = buffer.as_mut_slice().as_mut_ptr();
let buffer_slice = slice::from_raw_parts_mut(buffer_ptr, self.page_size);
try!(stream.read_at_least(self.page_size, buffer_slice).map_err(io_error));
buffer.set_len(self.page_size);
}
Ok(PageReference::Mutable(buffer.as_mut_slice()))
}
fn increment_change_counter(&mut self) -> PagerStreamResult<()> {
self.mut_page(0, |buffer| {
let old_change_counter = byteutils::read_u64_le(&buffer[1..9]);
let new_change_counter = old_change_counter + 1;
byteutils::write_u64_le(new_change_counter, &mut buffer[1..9]);
})
}
fn mark_page_as_removed(&mut self, page_id: u64) -> PagerStreamResult<()> {
// TODO: implement
Ok(())
}
}

146
src/sqlsyntax/ast.rs Normal file
View File

@ -0,0 +1,146 @@
#[derive(Debug, PartialEq)]
pub enum UnaryOp {
Negate
}
#[derive(Debug, PartialEq)]
pub enum BinaryOp {
Equal,
NotEqual,
LessThan,
LessThanOrEqual,
GreaterThan,
GreaterThanOrEqual,
And,
Or,
Add,
Subtract,
Multiply,
BitAnd,
BitOr,
Concatenate,
}
#[derive(Debug, PartialEq)]
pub enum Expression {
Ident(String),
StringLiteral(String),
Number(String),
/// name(argument1, argument2, argument3...)
FunctionCall { name: String, arguments: Vec<Expression> },
/// name(*)
FunctionCallAggregateAll { name: String },
UnaryOp {
expr: Box<Expression>,
op: UnaryOp
},
/// lhs op rhs
BinaryOp {
lhs: Box<Expression>,
rhs: Box<Expression>,
op: BinaryOp
}
}
#[derive(Debug)]
pub struct Table {
pub database_name: Option<String>,
pub table_name: String
}
#[derive(Debug)]
pub enum TableOrSubquery {
Subquery {
subquery: Box<SelectStatement>,
alias: Option<String>
},
Table {
table: Table,
alias: Option<String>
}
}
#[derive(Debug, PartialEq)]
pub enum SelectColumn {
AllColumns,
Expr {
expr: Expression,
alias: Option<String>
}
}
#[derive(Debug)]
pub struct SelectStatement {
pub result_columns: Vec<SelectColumn>,
pub from: From,
pub where_expr: Option<Expression>,
pub group_by: Vec<Expression>,
pub having: Option<Expression>
}
#[derive(Debug)]
pub enum From {
Cross(Vec<TableOrSubquery>),
// TODO: add joins
}
#[derive(Debug)]
pub struct InsertStatement {
pub table: Table,
pub into_columns: Option<Vec<String>>,
pub source: InsertSource
}
#[derive(Debug)]
pub enum InsertSource {
Values(Vec<Vec<Expression>>),
Select(Box<SelectStatement>)
}
#[derive(Debug)]
pub struct CreateTableColumnConstraint {
pub name: Option<String>,
pub constraint: CreateTableColumnConstraintType
}
#[derive(Debug)]
pub enum CreateTableColumnConstraintType {
PrimaryKey,
Unique,
Nullable,
ForeignKey {
table: Table,
columns: Option<Vec<String>>
}
}
#[derive(Debug)]
pub struct CreateTableColumn {
pub column_name: String,
pub type_name: String,
pub type_size: Option<String>,
/// * None if no array
/// * Some(None) if dynamic array: type[]
/// * Some(Some(_)) if fixed array: type[SIZE]
pub type_array_size: Option<Option<String>>,
pub constraints: Vec<CreateTableColumnConstraint>
}
#[derive(Debug)]
pub struct CreateTableStatement {
pub table: Table,
pub columns: Vec<CreateTableColumn>
}
#[derive(Debug)]
pub enum CreateStatement {
Table(CreateTableStatement)
}
#[derive(Debug)]
pub enum Statement {
Select(SelectStatement),
Insert(InsertStatement),
Create(CreateStatement)
}

423
src/sqlsyntax/lexer.rs Normal file
View File

@ -0,0 +1,423 @@
/// Disclaimer: The lexer is basically spaghetti. What did you expect?
#[derive(Clone, Debug, PartialEq)]
pub enum Token {
// Words
Select, From, Where, Group, Having, By, Limit,
Distinct,
Order, Asc, Desc,
As, Join, Inner, Outer, Left, Right, On,
Insert, Into, Values, Update, Delete,
Create, Table, Index, Constraint,
Primary, Key, Unique, References,
And, Or,
Between, In,
Is, Not, Null,
// Non-letter tokens
Equal,
NotEqual,
LessThan, LessThanOrEqual,
GreaterThan, GreaterThanOrEqual,
Plus, Minus,
LeftParen, RightParen,
LeftBracket, RightBracket,
Dot, Comma, Semicolon,
Ampersand, Pipe,
/// ||, the concatenate operator
DoublePipe,
/// *, the wildcard for SELECT
Asterisk,
/// ?, the prepared statement placeholder
PreparedStatementPlaceholder,
// Tokens with values
Number(String),
Ident(String),
StringLiteral(String)
}
fn character_to_token(c: char) -> Option<Token> {
use self::Token::*;
Some(match c {
'=' => Equal,
'<' => LessThan,
'>' => GreaterThan,
'+' => Plus,
'-' => Minus,
'(' => LeftParen,
')' => RightParen,
'[' => LeftBracket,
']' => RightBracket,
'.' => Dot,
',' => Comma,
';' => Semicolon,
'&' => Ampersand,
'|' => Pipe,
'*' => Asterisk,
'?' => PreparedStatementPlaceholder,
_ => return None
})
}
fn word_to_token(word: String) -> Token {
use self::Token::*;
// Make all letters lowercase for comparison
let word_cmp: String = word.chars().map( |c| c.to_lowercase() ).collect();
match word_cmp.as_slice() {
"select" => Select,
"from" => From,
"where" => Where,
"group" => Group,
"having" => Having,
"by" => By,
"limit" => Limit,
"distinct" => Distinct,
"order" => Order,
"asc" => Asc,
"desc" => Desc,
"as" => As,
"join" => Join,
"inner" => Inner,
"outer" => Outer,
"left" => Left,
"right" => Right,
"on" => On,
"insert" => Insert,
"into" => Into,
"values" => Values,
"update" => Update,
"delete" => Delete,
"create" => Create,
"table" => Table,
"index" => Index,
"constraint" => Constraint,
"primary" => Primary,
"key" => Key,
"unique" => Unique,
"references" => References,
"and" => And,
"or" => Or,
"between" => Between,
"in" => In,
"is" => Is,
"not" => Not,
"null" => Null,
_ => Ident(word)
}
}
enum LexerState {
NoState,
Word,
Backtick,
Apostrophe { escaping: bool },
Number { decimal: bool },
/// Disambiguate an operator sequence.
OperatorDisambiguate { first: char },
LineComment,
}
struct Lexer {
state: LexerState,
string_buffer: String,
tokens: Vec<Token>
}
impl Lexer {
fn no_state(&mut self, c: char) -> Result<LexerState, char> {
match c {
'a'...'z' | 'A'...'Z' | '_' => {
self.string_buffer.push(c);
Ok(LexerState::Word)
},
'`' => {
Ok(LexerState::Backtick)
}
'\'' => {
// string literal
Ok(LexerState::Apostrophe { escaping: false })
},
'0'...'9' => {
self.string_buffer.push(c);
Ok(LexerState::Number { decimal: false })
},
' ' | '\t' | '\n' => {
// whitespace
Ok(LexerState::NoState)
},
c => {
use self::Token::*;
match character_to_token(c) {
Some(LessThan) | Some(GreaterThan) | Some(Minus) | Some(Pipe) => {
Ok(LexerState::OperatorDisambiguate { first: c })
},
Some(token) => {
self.tokens.push(token);
Ok(LexerState::NoState)
},
None => {
// unknown character
Err(c)
}
}
}
}
}
fn move_string_buffer(&mut self) -> String {
use std::mem;
mem::replace(&mut self.string_buffer, String::new())
}
pub fn feed_character(&mut self, c: Option<char>) {
self.state = match self.state {
LexerState::NoState => {
match c {
Some(c) => self.no_state(c).unwrap(),
None => LexerState::NoState
}
},
LexerState::Word => {
match c {
Some(c) => match c {
'a'...'z' | 'A'...'Z' | '_' | '0'...'9' => {
self.string_buffer.push(c);
LexerState::Word
}
c => {
let buffer = self.move_string_buffer();
self.tokens.push(word_to_token(buffer));
self.no_state(c).unwrap()
}
},
None => {
let buffer = self.move_string_buffer();
self.tokens.push(word_to_token(buffer));
LexerState::NoState
}
}
},
LexerState::Backtick => {
match c {
Some('`') => {
let buffer = self.move_string_buffer();
self.tokens.push(Token::Ident(buffer));
LexerState::NoState
},
Some(c) => {
self.string_buffer.push(c);
LexerState::Backtick
},
None => {
// error: backtick did not finish
unimplemented!()
}
}
},
LexerState::Apostrophe { escaping } => {
if let Some(c) = c {
match (escaping, c) {
(false, '\'') => {
// unescaped apostrophe
let buffer = self.move_string_buffer();
self.tokens.push(Token::StringLiteral(buffer));
LexerState::NoState
},
(false, '\\') => {
// unescaped backslash
LexerState::Apostrophe { escaping: true }
},
(true, _) | _ => {
self.string_buffer.push(c);
LexerState::Apostrophe { escaping: false }
}
}
} else {
// error: apostrophe did not finish
unimplemented!()
}
},
LexerState::Number { decimal } => {
if let Some(c) = c {
match c {
'0'...'9' => {
self.string_buffer.push(c);
LexerState::Number { decimal: decimal }
},
'.' if !decimal => {
// Add a decimal point. None has been added yet.
self.string_buffer.push(c);
LexerState::Number { decimal: true }
},
c => {
let buffer = self.move_string_buffer();
self.tokens.push(Token::Number(buffer));
self.no_state(c).unwrap()
}
}
} else {
let buffer = self.move_string_buffer();
self.tokens.push(Token::Number(buffer));
LexerState::NoState
}
},
LexerState::OperatorDisambiguate { first } => {
use self::Token::*;
if let Some(c) = c {
match (first, c) {
('<', '>') => {
self.tokens.push(NotEqual);
LexerState::NoState
},
('<', '=') => {
self.tokens.push(LessThanOrEqual);
LexerState::NoState
},
('>', '=') => {
self.tokens.push(GreaterThanOrEqual);
LexerState::NoState
},
('|', '|') => {
self.tokens.push(DoublePipe);
LexerState::NoState
},
('-', '-') => {
LexerState::LineComment
},
_ => {
self.tokens.push(character_to_token(first).unwrap());
self.no_state(c).unwrap()
}
}
} else {
self.tokens.push(character_to_token(first).unwrap());
LexerState::NoState
}
},
LexerState::LineComment => {
match c {
Some('\n') => LexerState::NoState,
_ => LexerState::LineComment
}
}
};
}
}
pub fn parse(sql: &str) -> Vec<Token> {
let mut lexer = Lexer {
state: LexerState::NoState,
string_buffer: String::new(),
tokens: Vec::new()
};
for c in sql.chars() {
lexer.feed_character(Some(c));
}
lexer.feed_character(None);
lexer.tokens
}
#[cfg(test)]
mod test {
use super::parse;
fn id(value: &str) -> super::Token {
super::Token::Ident(value.to_string())
}
fn number(value: &str) -> super::Token {
super::Token::Number(value.to_string())
}
#[test]
fn test_sql_lexer_dontconfuseidentswithkeywords() {
use super::Token::*;
// Not: AS, Ident("df")
assert_eq!(parse("asdf"), vec![Ident("asdf".to_string())]);
}
#[test]
fn test_sql_lexer_escape() {
use super::Token::*;
// Escaped apostrophe
assert_eq!(parse(r"'\''"), vec![StringLiteral("'".to_string())]);
}
#[test]
fn test_sql_lexer_numbers() {
use super::Token::*;
assert_eq!(parse("12345"), vec![number("12345")]);
assert_eq!(parse("0.25"), vec![number("0.25")]);
assert_eq!(parse("0.25 + -0.25"), vec![number("0.25"), Plus, Minus, number("0.25")]);
assert_eq!(parse("-0.25 + 0.25"), vec![Minus, number("0.25"), Plus, number("0.25")]);
assert_eq!(parse("- 0.25 - -0.25"), vec![Minus, number("0.25"), Minus, Minus, number("0.25")]);
assert_eq!(parse("- 0.25 --0.25"), vec![Minus, number("0.25")]);
assert_eq!(parse("0.25 -0.25"), vec![number("0.25"), Minus, number("0.25")]);
}
#[test]
fn test_sql_lexer_query1() {
use super::Token::*;
assert_eq!(parse(" SeLECT a, b as alias1 , c alias2, d ` alias three ` fRoM table1 WHERE a='Hello World'; "),
vec![
Select, id("a"), Comma, id("b"), As, id("alias1"), Comma,
id("c"), id("alias2"), Comma, id("d"), id(" alias three "),
From, id("table1"),
Where, id("a"), Equal, StringLiteral("Hello World".to_string()), Semicolon
]
);
}
#[test]
fn test_sql_lexer_query2() {
use super::Token::*;
let query = r"
-- Get employee count from each department
SELECT d.id departmentId, count(e.id) employeeCount
FROM department d
LEFT JOIN employee e ON e.departmentId = d.id
GROUP BY departmentId;
";
assert_eq!(parse(query), vec![
Select, id("d"), Dot, id("id"), id("departmentId"), Comma, id("count"), LeftParen, id("e"), Dot, id("id"), RightParen, id("employeeCount"),
From, id("department"), id("d"),
Left, Join, id("employee"), id("e"), On, id("e"), Dot, id("departmentId"), Equal, id("d"), Dot, id("id"),
Group, By, id("departmentId"), Semicolon
]);
}
#[test]
fn test_sql_lexer_operators() {
use super::Token::*;
assert_eq!(parse("> = >=< =><"),
vec![
GreaterThan, Equal, GreaterThanOrEqual, LessThan, Equal, GreaterThan, LessThan
]
);
assert_eq!(parse(" ><>> >< >"),
vec![
GreaterThan, NotEqual, GreaterThan, GreaterThan, LessThan, GreaterThan
]
);
}
}

31
src/sqlsyntax/mod.rs Normal file
View File

@ -0,0 +1,31 @@
/// As of writing, there aren't any good or stable LALR(1) parser generators for Rust.
/// As a consequence, the lexer and parser are both written by hand.
pub mod ast;
mod lexer;
mod parser;
pub fn parse(query: &str) -> Vec<ast::Statement> {
let tokens = lexer::parse(query);
parser::parse(tokens.as_slice()).unwrap()
}
#[cfg(test)]
mod test {
use super::parse;
#[test]
fn test_sql_parser() {
parse("SELECT *, (name + 4), count(*) AS amount FROM (SELECT * FROM foo), table1 GROUP BY name HAVING count(*) > 5;");
parse("INSERT INTO table1 VALUES (1, 2), (3, 4), (5, 6);");
parse("INSERT INTO table1 (a, b) VALUES ('foo' || 'bar', 2);");
parse("INSERT INTO table1 SELECT * FROM foo;");
parse("CREATE TABLE test (
foo INT CONSTRAINT pk PRIMARY KEY,
bar VARCHAR(256),
data BYTE[32] NULL UNIQUE
);");
}
}

601
src/sqlsyntax/parser/mod.rs Normal file
View File

@ -0,0 +1,601 @@
/// The parser is a recursive descent parser.
use std::marker::{PhantomData, Sized};
use std::fmt;
use super::lexer::Token;
use super::ast::*;
mod tokens;
use self::tokens::Tokens;
pub enum RuleError {
ExpectingFirst(&'static str, Option<Token>),
Expecting(&'static str, Option<Token>),
NoMoreTokens
}
impl fmt::Display for RuleError {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
use self::RuleError::*;
match self {
&ExpectingFirst(s, Some(ref token)) => write!(f, "Expecting {}; got {:?}", s, token),
&Expecting(s, Some(ref token)) => write!(f, "Expecting {}; got {:?}", s, token),
&ExpectingFirst(s, None) => write!(f, "Expecting {}; got no more tokens", s),
&Expecting(s, None) => write!(f, "Expecting {}; got no more tokens", s),
&NoMoreTokens => write!(f, "No more tokens")
}
}
}
impl fmt::Debug for RuleError {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
write!(f, "{}", self)
}
}
type RuleResult<T> = Result<T, RuleError>;
fn rule_result_not_first<T>(rule_result: RuleResult<T>) -> RuleResult<T> {
use self::RuleError::*;
match rule_result {
Err(ExpectingFirst(s, t)) => Err(Expecting(s, t)),
value => value
}
}
macro_rules! try_notfirst {
($r:expr) => {
try!(rule_result_not_first($r))
}
}
trait Rule: Sized {
type Output: Sized = Self;
fn parse(tokens: &mut Tokens) -> RuleResult<Self::Output>;
}
trait RuleExt: Rule {
    /// Attempts to parse a rule. If the rule doesn't match at the current
    /// position, None is returned and the token cursor is left where it was.
    ///
    /// This parses a rule with a lookahead of 1:
    /// if the error from parse is ExpectingFirst, it's converted to None.
    /// All other errors are passed through unmodified.
fn parse_lookahead<'a>(tokens: &mut Tokens<'a>) -> RuleResult<Option<Self::Output>> {
let mut tokens_copy: Tokens<'a> = *tokens;
match Self::parse(&mut tokens_copy) {
Ok(v) => {
*tokens = tokens_copy;
Ok(Some(v))
},
Err(RuleError::ExpectingFirst(..)) => {
Ok(None)
},
Err(e) => Err(e)
}
}
fn parse_comma_delimited(tokens: &mut Tokens) -> RuleResult<Vec<Self::Output>> {
CommaDelimitedRule::<Self>::parse(tokens)
}
fn parse_series<'a>(tokens: &mut Tokens<'a>) -> RuleResult<Vec<Self::Output>> {
let mut v = Vec::new();
while let Some(value) = try!(Self::parse_lookahead(tokens)) {
v.push(value);
}
Ok(v)
}
}
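/// R, R, R, ... (one or more comma-separated R's)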
struct CommaDelimitedRule<R: Rule> {
_marker: PhantomData<R>
}
impl<R: Rule> Rule for CommaDelimitedRule<R> {
type Output = Vec<R::Output>;
fn parse(tokens: &mut Tokens) -> RuleResult<Vec<R::Output>> {
let mut v = Vec::new();
let value = try!(R::parse(tokens));
v.push(value);
// loop until no comma
while tokens.pop_if_token(&Token::Comma) {
// After the first item, ExpectingFirst gets converted to Expecting.
let value = try!(rule_result_not_first(R::parse(tokens)));
v.push(value);
}
Ok(v)
}
}
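/// ( R )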
struct ParensSurroundRule<R: Rule> {
_marker: PhantomData<R>
}
impl<R: Rule> Rule for ParensSurroundRule<R> {
type Output = R::Output;
fn parse(tokens: &mut Tokens) -> RuleResult<R::Output> {
try!(tokens.pop_expecting(&Token::LeftParen, "("));
let p = try_notfirst!(R::parse(tokens));
try_notfirst!(tokens.pop_expecting(&Token::RightParen, ")"));
Ok(p)
}
}
/// (R,R,R,...)
type ParensCommaDelimitedRule<R> = ParensSurroundRule<CommaDelimitedRule<R>>;
impl<R> RuleExt for R where R: Rule {}
struct Ident;
impl Rule for Ident {
type Output = String;
fn parse(tokens: &mut Tokens) -> RuleResult<String> {
tokens.pop_ident_expecting("identifier")
}
}
impl Rule for BinaryOp {
type Output = BinaryOp;
fn parse(tokens: &mut Tokens) -> RuleResult<BinaryOp> {
match try!(tokens.pop()) {
&Token::Equal => Ok(BinaryOp::Equal),
&Token::NotEqual => Ok(BinaryOp::NotEqual),
&Token::LessThan => Ok(BinaryOp::LessThan),
&Token::LessThanOrEqual => Ok(BinaryOp::LessThanOrEqual),
&Token::GreaterThan => Ok(BinaryOp::GreaterThan),
&Token::GreaterThanOrEqual => Ok(BinaryOp::GreaterThanOrEqual),
&Token::And => Ok(BinaryOp::And),
&Token::Or => Ok(BinaryOp::Or),
&Token::Plus => Ok(BinaryOp::Add),
&Token::Minus => Ok(BinaryOp::Subtract),
&Token::Asterisk => Ok(BinaryOp::Multiply),
&Token::Ampersand => Ok(BinaryOp::BitAnd),
&Token::Pipe => Ok(BinaryOp::BitOr),
&Token::DoublePipe => Ok(BinaryOp::Concatenate),
_ => Err(tokens.expecting("binary operator"))
}
}
}
impl UnaryOp {
fn precedence(&self) -> u8 {
use super::ast::UnaryOp::*;
match self {
&Negate => 6
}
}
}
impl BinaryOp {
/// Operators with a higher precedence have a higher number.
fn precedence(&self) -> u8 {
use super::ast::BinaryOp::*;
match self {
&Multiply => 5,
&Add | &Subtract | &BitAnd | &BitOr | &Concatenate => 4,
// comparison
&Equal | &NotEqual | &LessThan | &LessThanOrEqual | &GreaterThan | &GreaterThanOrEqual => 3,
            // logical connectives
&And => 2,
&Or => 1,
}
}
}
impl Rule for Expression {
type Output = Expression;
fn parse(tokens: &mut Tokens) -> RuleResult<Expression> {
Expression::parse_precedence(tokens, 0)
}
}
impl Expression {
/// Expressions are parsed using an algorithm known as "precedence climbing".
///
/// Precedence can be tricky to implement with recursive descent parsers,
    /// so this is a simple method that doesn't involve creating different
/// rules for different precedence levels.
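    ///
    /// For example, with `a + b * c`: after parsing `a` and seeing `+`
    /// (precedence 4), the right-hand side is parsed with min_precedence 5.
    /// That recursive call parses `b`, sees `*` (precedence 5 >= 5) and
    /// consumes `b * c`, so the whole expression groups as `a + (b * c)`.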
fn parse_precedence(tokens: &mut Tokens, min_precedence: u8) -> RuleResult<Expression> {
let mut expr = try!(Expression::parse_beginning(tokens));
let mut prev_tokens = *tokens;
// Test for after-expression tokens
while let Some(binary_op) = try_notfirst!(BinaryOp::parse_lookahead(tokens)) {
let binary_op_precedence = binary_op.precedence();
if binary_op_precedence >= min_precedence {
                // Assume left associativity: the right-hand side must bind strictly tighter
let q = binary_op_precedence + 1;
let rhs = try_notfirst!(Expression::parse_precedence(tokens, q));
let new_expr = Expression::BinaryOp {
lhs: Box::new(expr),
rhs: Box::new(rhs),
op: binary_op
};
expr = new_expr;
prev_tokens = *tokens;
} else {
// Backtrack if the precedence is lower
*tokens = prev_tokens;
// Let the previous expression rule with the lower precedence (if any) take over
break;
}
}
Ok(expr)
}
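    /// Parses the "atom" at the start of an expression: a unary plus or minus,
    /// a parenthesized expression, a function call, a plain identifier, a
    /// string literal, or a number.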
fn parse_beginning(tokens: &mut Tokens) -> RuleResult<Expression> {
if tokens.pop_if_token(&Token::Plus) {
// Unary, positive
// There's no point in making a Positive unary operator, so we'll "cheat" and use negate's precedence.
Ok(try_notfirst!(Expression::parse_precedence(tokens, UnaryOp::Negate.precedence())))
} else if tokens.pop_if_token(&Token::Minus) {
// Unary, negation
let e = try_notfirst!(Expression::parse_precedence(tokens, UnaryOp::Negate.precedence()));
Ok(Expression::UnaryOp {
expr: Box::new(e),
op: UnaryOp::Negate
})
} else if let Some(encased_expression) = try!(ParensSurroundRule::<Expression>::parse_lookahead(tokens)) {
// Expression is surrounded in parens for precedence.
Ok(encased_expression)
} else if let Some(ident) = tokens.pop_if_ident() {
if tokens.pop_if_token(&Token::LeftParen) {
// Function call
if tokens.pop_if_token(&Token::Asterisk) {
try_notfirst!(tokens.pop_expecting(&Token::RightParen, ") after aggregate asterisk. e.g. (*)"));
Ok(Expression::FunctionCallAggregateAll { name: ident })
} else {
let arguments = try_notfirst!(Expression::parse_comma_delimited(tokens));
try_notfirst!(tokens.pop_expecting(&Token::RightParen, ") after function arguments"));
Ok(Expression::FunctionCall { name: ident, arguments: arguments })
}
} else if tokens.pop_if_token(&Token::Dot) {
// Member access
unimplemented!()
} else {
Ok(Expression::Ident(ident))
}
} else if let Some(string) = tokens.pop_if_string_literal() {
Ok(Expression::StringLiteral(string))
} else if let Some(number) = tokens.pop_if_number() {
Ok(Expression::Number(number))
} else {
Err(tokens.expecting("identifier or number"))
}
}
}
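/// An optional-`AS` alias: accepts either `AS name` or just `name`.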
struct AsAlias;
impl Rule for AsAlias {
type Output = String;
fn parse(tokens: &mut Tokens) -> RuleResult<String> {
if tokens.pop_if_token(&Token::As) {
// Expecting alias
Ok(try_notfirst!(tokens.pop_ident_expecting("alias after `as` keyword")))
} else {
tokens.pop_ident_expecting("alias name or `as` keyword")
}
}
}
impl Rule for Table {
type Output = Table;
fn parse(tokens: &mut Tokens) -> RuleResult<Table> {
let table_name = try!(tokens.pop_ident_expecting("table name"));
Ok(Table {
database_name: None,
table_name: table_name
})
}
}
impl Rule for TableOrSubquery {
type Output = TableOrSubquery;
fn parse(tokens: &mut Tokens) -> RuleResult<TableOrSubquery> {
if let Some(select) = try!(ParensSurroundRule::<SelectStatement>::parse_lookahead(tokens)) {
// Subquery
let alias = try_notfirst!(AsAlias::parse_lookahead(tokens));
Ok(TableOrSubquery::Subquery {
subquery: Box::new(select),
alias: alias
})
} else if let Some(table) = try!(Table::parse_lookahead(tokens)) {
// Table
let alias = try_notfirst!(AsAlias::parse_lookahead(tokens));
Ok(TableOrSubquery::Table {
table: table,
alias: alias
})
} else {
Err(tokens.expecting("subquery or table name"))
}
}
}
impl Rule for SelectColumn {
type Output = SelectColumn;
fn parse(tokens: &mut Tokens) -> RuleResult<SelectColumn> {
if tokens.pop_if_token(&Token::Asterisk) {
Ok(SelectColumn::AllColumns)
} else if let Some(expr) = try!(Expression::parse_lookahead(tokens)) {
let alias = try_notfirst!(AsAlias::parse_lookahead(tokens));
Ok(SelectColumn::Expr {
expr: expr,
alias: alias
})
} else {
Err(tokens.expecting("* or expression for SELECT column"))
}
}
}
impl Rule for SelectStatement {
type Output = SelectStatement;
fn parse(tokens: &mut Tokens) -> RuleResult<SelectStatement> {
try!(tokens.pop_expecting(&Token::Select, "SELECT"));
let result_columns: Vec<SelectColumn> = try_notfirst!(SelectColumn::parse_comma_delimited(tokens));
let from = try_notfirst!(From::parse(tokens));
let where_expr = if tokens.pop_if_token(&Token::Where) {
Some(try_notfirst!(Expression::parse(tokens)))
} else {
None
};
let (group_by, having) = if tokens.pop_if_token(&Token::Group) {
try_notfirst!(tokens.pop_expecting(&Token::By, "BY after GROUP"));
let group_exprs = try_notfirst!(Expression::parse_comma_delimited(tokens));
if tokens.pop_if_token(&Token::Having) {
let having_expr = try_notfirst!(Expression::parse(tokens));
(group_exprs, Some(having_expr))
} else {
(group_exprs, None)
}
} else {
(Vec::new(), None)
};
Ok(SelectStatement {
result_columns: result_columns,
from: from,
where_expr: where_expr,
group_by: group_by,
having: having
})
}
}
impl Rule for From {
type Output = From;
fn parse(tokens: &mut Tokens) -> RuleResult<From> {
try!(tokens.pop_expecting(&Token::From, "FROM"));
let tables = try_notfirst!(TableOrSubquery::parse_comma_delimited(tokens));
Ok(From::Cross(tables))
}
}
impl Rule for InsertStatement {
type Output = InsertStatement;
fn parse(tokens: &mut Tokens) -> RuleResult<InsertStatement> {
try!(tokens.pop_expecting(&Token::Insert, "INSERT"));
try_notfirst!(tokens.pop_expecting(&Token::Into, "INTO"));
let table = try_notfirst!(Table::parse(tokens));
let into_columns = try_notfirst!(ParensCommaDelimitedRule::<Ident>::parse_lookahead(tokens));
let source = try_notfirst!(InsertSource::parse(tokens));
Ok(InsertStatement {
table: table,
into_columns: into_columns,
source: source
})
}
}
impl Rule for InsertSource {
type Output = InsertSource;
fn parse(tokens: &mut Tokens) -> RuleResult<InsertSource> {
if tokens.pop_if_token(&Token::Values) {
let values = try_notfirst!(CommaDelimitedRule::<ParensCommaDelimitedRule<Expression>>::parse(tokens));
Ok(InsertSource::Values(values))
} else if let Some(select) = try!(SelectStatement::parse_lookahead(tokens)) {
Ok(InsertSource::Select(Box::new(select)))
} else {
Err(tokens.expecting("VALUES or SELECT"))
}
}
}
impl Rule for CreateTableColumnConstraint {
type Output = CreateTableColumnConstraint;
fn parse(tokens: &mut Tokens) -> RuleResult<CreateTableColumnConstraint> {
if tokens.pop_if_token(&Token::Constraint) {
let name = try_notfirst!(tokens.pop_ident_expecting("constraint name after CONSTRAINT"));
let constraint = try_notfirst!(CreateTableColumnConstraintType::parse(tokens));
Ok(CreateTableColumnConstraint {
name: Some(name),
constraint: constraint
})
} else {
let constraint = try!(CreateTableColumnConstraintType::parse(tokens));
Ok(CreateTableColumnConstraint {
name: None,
constraint: constraint
})
}
}
}
impl Rule for CreateTableColumnConstraintType {
type Output = CreateTableColumnConstraintType;
fn parse(tokens: &mut Tokens) -> RuleResult<CreateTableColumnConstraintType> {
use super::ast::CreateTableColumnConstraintType::*;
if tokens.pop_if_token(&Token::Primary) {
try_notfirst!(tokens.pop_expecting(&Token::Key, "KEY after PRIMARY"));
Ok(PrimaryKey)
} else if tokens.pop_if_token(&Token::Unique) {
Ok(Unique)
} else if tokens.pop_if_token(&Token::Null) {
Ok(Nullable)
} else if tokens.pop_if_token(&Token::References) {
let table = try_notfirst!(Table::parse(tokens));
let columns = try_notfirst!(ParensCommaDelimitedRule::<Ident>::parse_lookahead(tokens));
Ok(ForeignKey {
table: table,
columns: columns
})
} else {
Err(tokens.expecting("column constraint"))
}
}
}
impl Rule for CreateTableColumn {
type Output = CreateTableColumn;
fn parse(tokens: &mut Tokens) -> RuleResult<CreateTableColumn> {
let column_name = try!(tokens.pop_ident_expecting("column name"));
let type_name = try_notfirst!(tokens.pop_ident_expecting("type name"));
let type_size = if tokens.pop_if_token(&Token::LeftParen) {
let x = try!(tokens.pop_number_expecting("column type size"));
try!(tokens.pop_expecting(&Token::RightParen, ")"));
Some(x)
} else {
None
};
let type_array_size = if tokens.pop_if_token(&Token::LeftBracket) {
if tokens.pop_if_token(&Token::RightBracket) {
// Dynamic array
Some(None)
} else {
let x = try!(tokens.pop_number_expecting("column array size"));
try!(tokens.pop_expecting(&Token::RightBracket, "]"));
Some(Some(x))
}
} else {
None
};
let constraints = try_notfirst!(CreateTableColumnConstraint::parse_series(tokens));
Ok(CreateTableColumn {
column_name: column_name,
type_name: type_name,
type_size: type_size,
type_array_size: type_array_size,
constraints: constraints
})
}
}
impl Rule for CreateTableStatement {
type Output = CreateTableStatement;
fn parse(tokens: &mut Tokens) -> RuleResult<CreateTableStatement> {
try!(tokens.pop_expecting(&Token::Table, "TABLE"));
let table = try_notfirst!(Table::parse(tokens));
try_notfirst!(tokens.pop_expecting(&Token::LeftParen, "( after table name"));
let columns = try_notfirst!(CreateTableColumn::parse_comma_delimited(tokens));
try_notfirst!(tokens.pop_expecting(&Token::RightParen, ") after table columns and constraints"));
Ok(CreateTableStatement {
table: table,
columns: columns
})
}
}
impl Rule for CreateStatement {
type Output = CreateStatement;
fn parse(tokens: &mut Tokens) -> RuleResult<CreateStatement> {
try!(tokens.pop_expecting(&Token::Create, "CREATE"));
if let Some(stmt) = try_notfirst!(CreateTableStatement::parse_lookahead(tokens)) {
Ok(CreateStatement::Table(stmt))
} else {
Err(tokens.expecting("TABLE"))
}
}
}
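// The output is Option<Statement>: None means an empty statement (a bare
// semicolon), which the top-level parser skips.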
impl Rule for Statement {
type Output = Option<Statement>;
fn parse(tokens: &mut Tokens) -> RuleResult<Option<Statement>> {
let statement = if let Some(select) = try!(SelectStatement::parse_lookahead(tokens)) {
Some(Statement::Select(select))
} else if let Some(insert) = try!(InsertStatement::parse_lookahead(tokens)) {
Some(Statement::Insert(insert))
} else if let Some(create) = try!(CreateStatement::parse_lookahead(tokens)) {
Some(Statement::Create(create))
} else {
None
};
if let Some(statement) = statement {
try_notfirst!(tokens.pop_expecting(&Token::Semicolon, "semicolon"));
Ok(Some(statement))
} else {
try!(tokens.pop_expecting(&Token::Semicolon, "semicolon"));
Ok(None)
}
}
}
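/// Parses a token slice into a list of statements. Parsing stops at the end
/// of input or at the first token that cannot begin a statement.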
pub fn parse(tokens_slice: &[Token]) -> Result<Vec<Statement>, RuleError> {
let mut tokens = Tokens::new(tokens_slice);
let mut statements = Vec::new();
while let Some(value) = try!(Statement::parse_lookahead(&mut tokens)) {
if let Some(stmt) = value {
statements.push(stmt);
}
}
Ok(statements)
}

View File

@ -0,0 +1,147 @@
use super::super::lexer::Token;
use super::{Rule, RuleExt, RuleError, RuleResult};
use super::{rule_result_not_first};
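/// A cursor over a borrowed slice of tokens.
///
/// `Tokens` is `Copy`, which makes backtracking cheap: lookahead callers copy
/// the cursor, try a rule against the copy, and only commit the copy back on
/// success.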
#[derive(Copy)]
pub struct Tokens<'a> {
tokens: &'a [Token]
}
impl<'a> Tokens<'a> {
fn peek_clone(&self) -> Option<Token> {
if self.tokens.len() > 0 {
Some(self.tokens[0].clone())
} else {
None
}
}
pub fn new(tokens: &'a [Token]) -> Tokens<'a> {
Tokens {
tokens: tokens
}
}
pub fn expecting(&self, expecting_message: &'static str) -> RuleError {
RuleError::ExpectingFirst(expecting_message, self.peek_clone())
}
pub fn pop_expecting(&mut self, token: &Token, expecting_message: &'static str) -> RuleResult<()> {
if self.pop_if_token(token) { Ok(()) }
else { Err(self.expecting(expecting_message)) }
}
#[must_use]
pub fn pop_if_token(&mut self, token: &Token) -> bool {
if self.tokens.len() > 0 {
if &self.tokens[0] == token {
self.tokens = &self.tokens[1..];
true
} else {
false
}
} else {
false
}
}
#[must_use]
pub fn pop_if_number(&mut self) -> Option<String> {
if self.tokens.len() > 0 {
let token = &self.tokens[0];
if let &Token::Number(ref s) = token {
                let number = s.clone();
                self.tokens = &self.tokens[1..];
                Some(number)
} else {
None
}
} else {
None
}
}
#[must_use]
pub fn pop_if_string_literal(&mut self) -> Option<String> {
if self.tokens.len() > 0 {
let token = &self.tokens[0];
if let &Token::StringLiteral(ref s) = token {
                let string = s.clone();
                self.tokens = &self.tokens[1..];
                Some(string)
} else {
None
}
} else {
None
}
}
pub fn pop_if_ident(&mut self) -> Option<String> {
if self.tokens.len() > 0 {
let token = &self.tokens[0];
if let &Token::Ident(ref s) = token {
let ident = s.clone();
self.tokens = &self.tokens[1..];
Some(ident)
} else {
None
}
} else {
None
}
}
pub fn pop_ident_expecting(&mut self, expecting_message: &'static str) -> RuleResult<String> {
if self.tokens.len() > 0 {
let token = &self.tokens[0];
if let &Token::Ident(ref s) = token {
let ident = s.clone();
self.tokens = &self.tokens[1..];
Ok(ident)
} else {
Err(self.expecting(expecting_message))
}
} else {
Err(self.expecting(expecting_message))
}
}
pub fn pop_number_expecting(&mut self, expecting_message: &'static str) -> RuleResult<String> {
if self.tokens.len() > 0 {
let token = &self.tokens[0];
if let &Token::Number(ref s) = token {
                let number = s.clone();
                self.tokens = &self.tokens[1..];
                Ok(number)
} else {
Err(self.expecting(expecting_message))
}
} else {
Err(self.expecting(expecting_message))
}
}
pub fn pop(&mut self) -> RuleResult<&'a Token> {
if self.tokens.len() > 0 {
let token = &self.tokens[0];
self.tokens = &self.tokens[1..];
Ok(token)
} else {
Err(RuleError::NoMoreTokens)
}
}
pub fn peek(&self) -> RuleResult<&'a Token> {
if self.tokens.len() > 0 {
Ok(&self.tokens[0])
} else {
Err(RuleError::NoMoreTokens)
}
}
}