Skip to content

Commit

Permalink
feat: CSV-writer escape carriage return (#15399)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Mar 30, 2024
1 parent be41697 commit 345ca75
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
28 changes: 21 additions & 7 deletions crates/polars-io/src/csv/write_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use arrow::legacy::time_zone::Tz;
use arrow::temporal_conversions;
#[cfg(feature = "timezones")]
use chrono::TimeZone;
use memchr::{memchr, memchr2};
use memchr::{memchr3, memmem};
use polars_core::prelude::*;
use polars_core::series::SeriesIter;
use polars_core::POOL;
Expand All @@ -20,15 +20,23 @@ use serde::{Deserialize, Serialize};

use super::write::QuoteStyle;

fn fmt_and_escape_str(f: &mut Vec<u8>, v: &str, options: &SerializeOptions) -> std::io::Result<()> {
const LF: u8 = b'\n';
const CR: u8 = b'\r';

fn fmt_and_escape_str(
f: &mut Vec<u8>,
v: &str,
options: &SerializeOptions,
find_quotes: &memmem::Finder,
) -> std::io::Result<()> {
if options.quote_style == QuoteStyle::Never {
return write!(f, "{v}");
}
let quote = options.quote_char as char;
if v.is_empty() {
return write!(f, "{quote}{quote}");
}
let needs_escaping = memchr(options.quote_char, v.as_bytes()).is_some();
let needs_escaping = find_quotes.find(v.as_bytes()).is_some();
if needs_escaping {
let replaced = unsafe {
// Replace from single quote " to double quote "".
Expand All @@ -41,7 +49,7 @@ fn fmt_and_escape_str(f: &mut Vec<u8>, v: &str, options: &SerializeOptions) -> s
}
let surround_with_quotes = match options.quote_style {
QuoteStyle::Always | QuoteStyle::NonNumeric => true,
QuoteStyle::Necessary => memchr2(options.separator, b'\n', v.as_bytes()).is_some(),
QuoteStyle::Necessary => memchr3(options.separator, LF, CR, v.as_bytes()).is_some(),
QuoteStyle::Never => false,
};

Expand Down Expand Up @@ -72,17 +80,18 @@ unsafe fn write_any_value(
datetime_formats: &[&str],
time_zones: &[Option<Tz>],
i: usize,
find_quotes: &memmem::Finder,
) -> PolarsResult<()> {
match value {
// First do the string-like types as they know how to deal with quoting.
AnyValue::String(v) => {
fmt_and_escape_str(f, v, options)?;
fmt_and_escape_str(f, v, options, find_quotes)?;
Ok(())
},
#[cfg(feature = "dtype-categorical")]
AnyValue::Categorical(idx, rev_map, _) | AnyValue::Enum(idx, rev_map, _) => {
let v = rev_map.get(idx);
fmt_and_escape_str(f, v, options)?;
fmt_and_escape_str(f, v, options, find_quotes)?;
Ok(())
},
_ => {
Expand Down Expand Up @@ -410,6 +419,8 @@ pub(crate) fn write<W: Write>(

let last_ptr = &col_iters[col_iters.len() - 1] as *const SeriesIter;
let mut finished = false;
let binding = &[options.quote_char];
let find_quotes = memmem::Finder::new(binding);
// loop rows
while !finished {
for (i, col) in &mut col_iters.iter_mut().enumerate() {
Expand All @@ -422,6 +433,7 @@ pub(crate) fn write<W: Write>(
&datetime_formats,
&time_zones,
i,
&find_quotes,
)?;
},
None => {
Expand Down Expand Up @@ -475,8 +487,10 @@ pub(crate) fn write_header<W: Write>(
let mut escaped_names: Vec<String> = Vec::with_capacity(names.len());
let mut nm: Vec<u8> = vec![];

let binding = &[options.quote_char];
let find_quotes = memmem::Finder::new(binding);
for name in names {
fmt_and_escape_str(&mut nm, name, options)?;
fmt_and_escape_str(&mut nm, name, options, &find_quotes)?;
unsafe {
// SAFETY: we know headers will be valid UTF-8 at this point
escaped_names.push(std::str::from_utf8_unchecked(&nm).to_string());
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,3 +1969,11 @@ def test_read_csv_single_column(columns: list[str] | str) -> None:
def test_csv_invalid_escape_utf8_14960() -> None:
with pytest.raises(pl.ComputeError, match=r"field is not properly escaped"):
pl.read_csv('col1\n""•'.encode())


def test_csv_escape_cf_15349() -> None:
f = io.BytesIO()
df = pl.DataFrame({"test": ["normal", "with\rcr"]})
df.write_csv(f)
f.seek(0)
assert f.read() == b'test\nnormal\n"with\rcr"\n'

0 comments on commit 345ca75

Please sign in to comment.