From 697de36a3f81d61d2017557f48bb54f840ab3ba3 Mon Sep 17 00:00:00 2001 From: ShreyeshArangath Date: Fri, 13 Feb 2026 20:05:37 -0800 Subject: [PATCH 1/4] feat: add Python bindings for accessing ExecutionMetrics --- python/datafusion/__init__.py | 4 +- python/datafusion/plan.py | 106 ++++++++++++++++++++++++ python/tests/test_plans.py | 152 +++++++++++++++++++++++++++++++++- src/dataframe.rs | 44 +++++++++- src/lib.rs | 3 + src/metrics.rs | 143 ++++++++++++++++++++++++++++++++ src/physical_plan.rs | 6 ++ 7 files changed, 452 insertions(+), 6 deletions(-) create mode 100644 src/metrics.rs diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 2e6f81166..f0314d8dd 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -55,7 +55,7 @@ from .expr import Expr, WindowFrame from .io import read_avro, read_csv, read_json, read_parquet from .options import CsvReadOptions -from .plan import ExecutionPlan, LogicalPlan +from .plan import ExecutionPlan, LogicalPlan, Metric, MetricsSet from .record_batch import RecordBatch, RecordBatchStream from .user_defined import ( Accumulator, @@ -85,6 +85,8 @@ "Expr", "InsertOp", "LogicalPlan", + "Metric", + "MetricsSet", "ParquetColumnOptions", "ParquetWriterOptions", "RecordBatch", diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index fb54fd624..d46ff1a00 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -29,6 +29,8 @@ __all__ = [ "ExecutionPlan", "LogicalPlan", + "Metric", + "MetricsSet", ] @@ -151,3 +153,107 @@ def to_proto(self) -> bytes: Tables created in memory from record batches are currently not supported. 
""" return self._raw_plan.to_proto() + + def metrics(self) -> MetricsSet | None: + """Return metrics for this plan node after execution, or None if unavailable.""" + raw = self._raw_plan.metrics() + if raw is None: + return None + return MetricsSet(raw) + + def collect_metrics(self) -> list[tuple[str, MetricsSet]]: + """Walk the plan tree and collect metrics from all operators. + + Returns a list of (operator_name, MetricsSet) tuples. + """ + result: list[tuple[str, MetricsSet]] = [] + + def _walk(node: ExecutionPlan) -> None: + ms = node.metrics() + if ms is not None: + result.append((node.display(), ms)) + for child in node.children(): + _walk(child) + + _walk(self) + return result + + +class MetricsSet: + """A set of metrics for a single execution plan operator. + + Provides both individual metric access and convenience aggregations + across partitions. + """ + + def __init__(self, raw: df_internal.MetricsSet) -> None: + """This constructor should not be called by the end user.""" + self._raw = raw + + def metrics(self) -> list[Metric]: + """Return all individual metrics in this set.""" + return [Metric(m) for m in self._raw.metrics()] + + @property + def output_rows(self) -> int | None: + """Sum of output_rows across all partitions.""" + return self._raw.output_rows() + + @property + def elapsed_compute(self) -> int | None: + """Sum of elapsed_compute across all partitions, in nanoseconds.""" + return self._raw.elapsed_compute() + + @property + def spill_count(self) -> int | None: + """Sum of spill_count across all partitions.""" + return self._raw.spill_count() + + @property + def spilled_bytes(self) -> int | None: + """Sum of spilled_bytes across all partitions.""" + return self._raw.spilled_bytes() + + @property + def spilled_rows(self) -> int | None: + """Sum of spilled_rows across all partitions.""" + return self._raw.spilled_rows() + + def sum_by_name(self, name: str) -> int | None: + """Return the sum of metrics matching the given name.""" + return 
self._raw.sum_by_name(name) + + def __repr__(self) -> str: + """Return a string representation of the metrics set.""" + return repr(self._raw) + + +class Metric: + """A single execution metric with name, value, partition, and labels.""" + + def __init__(self, raw: df_internal.Metric) -> None: + """This constructor should not be called by the end user.""" + self._raw = raw + + @property + def name(self) -> str: + """The name of this metric (e.g. ``output_rows``).""" + return self._raw.name + + @property + def value(self) -> int | None: + """The numeric value of this metric, or None for non-numeric types.""" + return self._raw.value + + @property + def partition(self) -> int | None: + """The partition this metric applies to, or None if global.""" + return self._raw.partition + + def labels(self) -> dict[str, str]: + """Return the labels associated with this metric.""" + return self._raw.labels() + + def __repr__(self) -> str: + """Return a string representation of the metric.""" + return repr(self._raw) diff --git a/python/tests/test_plans.py b/python/tests/test_plans.py index 396acbe97..05a721db1 100644 --- a/python/tests/test_plans.py +++ b/python/tests/test_plans.py @@ -16,7 +16,13 @@ # under the License. 
import pytest -from datafusion import ExecutionPlan, LogicalPlan, SessionContext +from datafusion import ( + ExecutionPlan, + LogicalPlan, + Metric, + MetricsSet, + SessionContext, +) # Note: We must use CSV because memory tables are currently not supported for @@ -40,3 +46,147 @@ def test_logical_plan_to_proto(ctx, df) -> None: execution_plan = ExecutionPlan.from_proto(ctx, execution_plan_bytes) assert str(original_execution_plan) == str(execution_plan) + + +def test_execution_plan_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + df.collect() + plan = df.execution_plan() + + found_metrics = False + + def _check(node): + nonlocal found_metrics + ms = node.metrics() + if ms is not None and ms.output_rows is not None and ms.output_rows > 0: + found_metrics = True + for child in node.children(): + _check(child) + + _check(plan) + assert found_metrics + + +def test_metric_properties() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + df.collect() + plan = df.execution_plan() + + for _, ms in plan.collect_metrics(): + for metric in ms.metrics(): + assert isinstance(metric, Metric) + assert isinstance(metric.name, str) + assert len(metric.name) > 0 + assert metric.partition is None or isinstance(metric.partition, int) + assert isinstance(metric.labels(), dict) + return + pytest.skip("No metrics found") + + +def test_metrics_tree_walk() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'a'), (4, 'b')") + df = ctx.sql("SELECT column2, COUNT(*) FROM t GROUP BY column2") + + df.collect() + plan = df.execution_plan() + + results = plan.collect_metrics() + assert len(results) >= 2 + for name, ms in results: + assert isinstance(name, str) + assert isinstance(ms, MetricsSet) + + +def 
test_no_metrics_before_execution() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)") + df = ctx.sql("SELECT * FROM t") + plan = df.execution_plan() + ms = plan.metrics() + assert ms is None or ms.output_rows is None or ms.output_rows == 0 + + +def test_metrics_repr() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)") + df = ctx.sql("SELECT * FROM t") + + df.collect() + plan = df.execution_plan() + + for _, ms in plan.collect_metrics(): + r = repr(ms) + assert isinstance(r, str) + for metric in ms.metrics(): + mr = repr(metric) + assert isinstance(mr, str) + assert len(mr) > 0 + return + pytest.skip("No metrics found") + + +def test_collect_partitioned_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + partitions = df.collect_partitioned() + plan = df.execution_plan() + assert len(partitions) == plan.partition_count + + # Metrics should be populated after collecting + found_metrics = False + for _, ms in plan.collect_metrics(): + if ms.output_rows is not None and ms.output_rows > 0: + found_metrics = True + assert found_metrics + + +def test_execute_stream_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + stream = df.execute_stream() + + # Consume the stream (iterates over RecordBatches) + batches = list(stream) + assert len(batches) >= 1 + + # Metrics should be populated after consuming the stream + plan = df.execution_plan() + found_metrics = False + for name, ms in plan.collect_metrics(): + assert isinstance(name, str) + assert isinstance(ms, MetricsSet) + if ms.output_rows is not None and ms.output_rows > 0: + found_metrics = True + assert found_metrics + + +def test_execute_stream_partitioned_metrics() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t 
AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + streams = df.execute_stream_partitioned() + + # Consume all partition streams + for stream in streams: + for _ in stream: + pass + + # Metrics should be populated (FilterExec reports output_rows) + plan = df.execution_plan() + found_metrics = False + for _, ms in plan.collect_metrics(): + if ms.output_rows is not None and ms.output_rows > 0: + found_metrics = True + assert found_metrics diff --git a/src/dataframe.rs b/src/dataframe.rs index 53fab58c6..3f8f3c94f 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -49,6 +49,13 @@ use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods}; use crate::common::data_type::PyScalarValue; +use datafusion::physical_plan::{ + ExecutionPlan as DFExecutionPlan, + collect as df_collect, + collect_partitioned as df_collect_partitioned, + execute_stream as df_execute_stream, + execute_stream_partitioned as df_execute_stream_partitioned, +}; use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err}; use crate::expr::PyExpr; use crate::expr::sort_expr::{PySortExpr, to_sort_expressions}; @@ -289,6 +296,9 @@ pub struct PyDataFrame { // In IPython environment cache batches between __repr__ and _repr_html_ calls. batches: SharedCachedBatches, + + // Cache the last physical plan so that metrics are available after execution. + last_plan: Arc>>>, } impl PyDataFrame { @@ -297,6 +307,7 @@ impl PyDataFrame { Self { df: Arc::new(df), batches: Arc::new(Mutex::new(None)), + last_plan: Arc::new(Mutex::new(None)), } } @@ -626,7 +637,12 @@ impl PyDataFrame { /// Unless some order is specified in the plan, there is no /// guarantee of the order of the result. fn collect<'py>(&self, py: Python<'py>) -> PyResult>> { - let batches = wait_for_future(py, self.df.as_ref().clone().collect())? + let df = self.df.as_ref().clone(); + let plan = wait_for_future(py, df.create_physical_plan())? 
+ .map_err(PyDataFusionError::from)?; + *self.last_plan.lock() = Some(Arc::clone(&plan)); + let task_ctx = Arc::new(self.df.as_ref().task_ctx()); + let batches = wait_for_future(py, df_collect(plan, task_ctx))? .map_err(PyDataFusionError::from)?; // cannot use PyResult> return type due to // https://github.com/PyO3/pyo3/issues/1813 @@ -642,7 +658,12 @@ impl PyDataFrame { /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch /// maintaining the input partitioning. fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult>>> { - let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())? + let df = self.df.as_ref().clone(); + let plan = wait_for_future(py, df.create_physical_plan())? + .map_err(PyDataFusionError::from)?; + *self.last_plan.lock() = Some(Arc::clone(&plan)); + let task_ctx = Arc::new(self.df.as_ref().task_ctx()); + let batches = wait_for_future(py, df_collect_partitioned(plan, task_ctx))? .map_err(PyDataFusionError::from)?; batches @@ -802,7 +823,13 @@ impl PyDataFrame { } /// Get the execution plan for this `DataFrame` + /// + /// If the DataFrame has already been executed (e.g. via `collect()`), + /// returns the cached plan which includes populated metrics. 
fn execution_plan(&self, py: Python) -> PyDataFusionResult { + if let Some(plan) = self.last_plan.lock().as_ref() { + return Ok(PyExecutionPlan::new(Arc::clone(plan))); + } let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??; Ok(plan.into()) } @@ -1127,13 +1154,22 @@ impl PyDataFrame { fn execute_stream(&self, py: Python) -> PyDataFusionResult { let df = self.df.as_ref().clone(); - let stream = spawn_future(py, async move { df.execute_stream().await })?; + let plan = wait_for_future(py, df.create_physical_plan())??; + *self.last_plan.lock() = Some(Arc::clone(&plan)); + let task_ctx = Arc::new(self.df.as_ref().task_ctx()); + let stream = spawn_future(py, async move { df_execute_stream(plan, task_ctx) })?; Ok(PyRecordBatchStream::new(stream)) } fn execute_stream_partitioned(&self, py: Python) -> PyResult> { let df = self.df.as_ref().clone(); - let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?; + let plan = wait_for_future(py, df.create_physical_plan())? 
+ .map_err(PyDataFusionError::from)?; + *self.last_plan.lock() = Some(Arc::clone(&plan)); + let task_ctx = Arc::new(self.df.as_ref().task_ctx()); + let streams = spawn_future(py, async move { + df_execute_stream_partitioned(plan, task_ctx) + })?; Ok(streams.into_iter().map(PyRecordBatchStream::new).collect()) } diff --git a/src/lib.rs b/src/lib.rs index 081366b20..7c21ae95c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,7 @@ pub mod errors; pub mod expr; #[allow(clippy::borrow_deref_ref)] mod functions; +pub mod metrics; mod options; pub mod physical_plan; mod pyarrow_filter_expression; @@ -96,6 +97,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/metrics.rs b/src/metrics.rs new file mode 100644 index 000000000..e333ea791 --- /dev/null +++ b/src/metrics.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::physical_plan::metrics::{MetricValue, MetricsSet, Metric}; +use pyo3::prelude::*; + +#[pyclass(frozen, name = "MetricsSet", module = "datafusion")] +#[derive(Debug, Clone)] +pub struct PyMetricsSet { + metrics: MetricsSet, +} + +impl PyMetricsSet { + pub fn new(metrics: MetricsSet) -> Self { + Self { metrics } + } +} + +#[pymethods] +impl PyMetricsSet { + /// Returns all individual metrics in this set. + fn metrics(&self) -> Vec { + self.metrics + .iter() + .map(|m| PyMetric::new(Arc::clone(m))) + .collect() + } + + /// Returns the sum of all `output_rows` metrics, or None if not present. + fn output_rows(&self) -> Option { + self.metrics.output_rows() + } + + /// Returns the sum of all `elapsed_compute` metrics in nanoseconds, or None if not present. + fn elapsed_compute(&self) -> Option { + self.metrics.elapsed_compute() + } + + /// Returns the sum of all `spill_count` metrics, or None if not present. + fn spill_count(&self) -> Option { + self.metrics.spill_count() + } + + /// Returns the sum of all `spilled_bytes` metrics, or None if not present. + fn spilled_bytes(&self) -> Option { + self.metrics.spilled_bytes() + } + + /// Returns the sum of all `spilled_rows` metrics, or None if not present. + fn spilled_rows(&self) -> Option { + self.metrics.spilled_rows() + } + + /// Returns the sum of metrics matching the given name. + fn sum_by_name(&self, name: &str) -> Option { + self.metrics.sum_by_name(name).map(|v| v.as_usize()) + } + + fn __repr__(&self) -> String { + format!("{}", self.metrics) + } +} + +#[pyclass(frozen, name = "Metric", module = "datafusion")] +#[derive(Debug, Clone)] +pub struct PyMetric { + metric: Arc, +} + +impl PyMetric { + pub fn new(metric: Arc) -> Self { + Self { metric } + } +} + +#[pymethods] +impl PyMetric { + /// Returns the name of this metric. 
+ #[getter] + fn name(&self) -> String { + self.metric.value().name().to_string() + } + + /// Returns the numeric value of this metric, or None for non-numeric types. + #[getter] + fn value(&self) -> Option { + match self.metric.value() { + MetricValue::OutputRows(c) => Some(c.value()), + MetricValue::OutputBytes(c) => Some(c.value()), + MetricValue::ElapsedCompute(t) => Some(t.value()), + MetricValue::SpillCount(c) => Some(c.value()), + MetricValue::SpilledBytes(c) => Some(c.value()), + MetricValue::SpilledRows(c) => Some(c.value()), + MetricValue::CurrentMemoryUsage(g) => Some(g.value()), + MetricValue::Count { count, .. } => Some(count.value()), + MetricValue::Gauge { gauge, .. } => Some(gauge.value()), + MetricValue::Time { time, .. } => Some(time.value()), + MetricValue::StartTimestamp(ts) => { + ts.value().and_then(|dt| dt.timestamp_nanos_opt().map(|n| n as usize)) + } + MetricValue::EndTimestamp(ts) => { + ts.value().and_then(|dt| dt.timestamp_nanos_opt().map(|n| n as usize)) + } + _ => None, + } + } + + /// Returns the partition this metric is for, or None if it applies to all partitions. + #[getter] + fn partition(&self) -> Option { + self.metric.partition() + } + + /// Returns the labels associated with this metric as a dict. 
+ fn labels(&self) -> HashMap { + self.metric + .labels() + .iter() + .map(|l| (l.name().to_string(), l.value().to_string())) + .collect() + } + + fn __repr__(&self) -> String { + format!("{}", self.metric.value()) + } +} diff --git a/src/physical_plan.rs b/src/physical_plan.rs index 0069e5e6e..319d27efe 100644 --- a/src/physical_plan.rs +++ b/src/physical_plan.rs @@ -26,6 +26,7 @@ use pyo3::types::PyBytes; use crate::context::PySessionContext; use crate::errors::PyDataFusionResult; +use crate::metrics::PyMetricsSet; #[pyclass(frozen, name = "ExecutionPlan", module = "datafusion", subclass)] #[derive(Debug, Clone)] @@ -90,6 +91,11 @@ impl PyExecutionPlan { Ok(Self::new(plan)) } + /// Returns metrics for this plan node after execution, or None if unavailable. + pub fn metrics(&self) -> Option { + self.plan.metrics().map(PyMetricsSet::new) + } + fn __repr__(&self) -> String { self.display_indent() } From 0a57da6d4c53f3285ffc58f3376a1d52d25c4533 Mon Sep 17 00:00:00 2001 From: ShreyeshArangath Date: Sat, 14 Feb 2026 18:00:45 -0800 Subject: [PATCH 2/4] test: imporve tests --- python/tests/test_plans.py | 80 ++++++++------------------------------ 1 file changed, 17 insertions(+), 63 deletions(-) diff --git a/python/tests/test_plans.py b/python/tests/test_plans.py index 05a721db1..d3525b08c 100644 --- a/python/tests/test_plans.py +++ b/python/tests/test_plans.py @@ -48,25 +48,21 @@ def test_logical_plan_to_proto(ctx, df) -> None: assert str(original_execution_plan) == str(execution_plan) -def test_execution_plan_metrics() -> None: +def test_metrics_tree_walk() -> None: ctx = SessionContext() ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") df = ctx.sql("SELECT * FROM t WHERE column1 > 1") - df.collect() plan = df.execution_plan() + results = plan.collect_metrics() + assert len(results) >= 1 found_metrics = False - - def _check(node): - nonlocal found_metrics - ms = node.metrics() - if ms is not None and ms.output_rows is not None and ms.output_rows > 0: + 
for name, ms in results: + assert isinstance(name, str) + assert isinstance(ms, MetricsSet) + if ms.output_rows is not None and ms.output_rows > 0: found_metrics = True - for child in node.children(): - _check(child) - - _check(plan) assert found_metrics @@ -74,36 +70,25 @@ def test_metric_properties() -> None: ctx = SessionContext() ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") df = ctx.sql("SELECT * FROM t WHERE column1 > 1") - df.collect() plan = df.execution_plan() for _, ms in plan.collect_metrics(): + r = repr(ms) + assert isinstance(r, str) for metric in ms.metrics(): assert isinstance(metric, Metric) assert isinstance(metric.name, str) assert len(metric.name) > 0 assert metric.partition is None or isinstance(metric.partition, int) assert isinstance(metric.labels(), dict) + mr = repr(metric) + assert isinstance(mr, str) + assert len(mr) > 0 return pytest.skip("No metrics found") -def test_metrics_tree_walk() -> None: - ctx = SessionContext() - ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'a'), (4, 'b')") - df = ctx.sql("SELECT column2, COUNT(*) FROM t GROUP BY column2") - - df.collect() - plan = df.execution_plan() - - results = plan.collect_metrics() - assert len(results) >= 2 - for name, ms in results: - assert isinstance(name, str) - assert isinstance(ms, MetricsSet) - - def test_no_metrics_before_execution() -> None: ctx = SessionContext() ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)") @@ -113,35 +98,14 @@ def test_no_metrics_before_execution() -> None: assert ms is None or ms.output_rows is None or ms.output_rows == 0 -def test_metrics_repr() -> None: - ctx = SessionContext() - ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)") - df = ctx.sql("SELECT * FROM t") - - df.collect() - plan = df.execution_plan() - - for _, ms in plan.collect_metrics(): - r = repr(ms) - assert isinstance(r, str) - for metric in ms.metrics(): - mr = repr(metric) - assert isinstance(mr, str) - assert len(mr) > 0 - return - pytest.skip("No 
metrics found") - - def test_collect_partitioned_metrics() -> None: ctx = SessionContext() ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") df = ctx.sql("SELECT * FROM t WHERE column1 > 1") - partitions = df.collect_partitioned() + df.collect_partitioned() plan = df.execution_plan() - assert len(partitions) == plan.partition_count - # Metrics should be populated after collecting found_metrics = False for _, ms in plan.collect_metrics(): if ms.output_rows is not None and ms.output_rows > 0: @@ -154,18 +118,12 @@ def test_execute_stream_metrics() -> None: ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") df = ctx.sql("SELECT * FROM t WHERE column1 > 1") - stream = df.execute_stream() - - # Consume the stream (iterates over RecordBatches) - batches = list(stream) - assert len(batches) >= 1 + for _ in df.execute_stream(): + pass - # Metrics should be populated after consuming the stream plan = df.execution_plan() found_metrics = False - for name, ms in plan.collect_metrics(): - assert isinstance(name, str) - assert isinstance(ms, MetricsSet) + for _, ms in plan.collect_metrics(): if ms.output_rows is not None and ms.output_rows > 0: found_metrics = True assert found_metrics @@ -176,14 +134,10 @@ def test_execute_stream_partitioned_metrics() -> None: ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") df = ctx.sql("SELECT * FROM t WHERE column1 > 1") - streams = df.execute_stream_partitioned() - - # Consume all partition streams - for stream in streams: + for stream in df.execute_stream_partitioned(): for _ in stream: pass - # Metrics should be populated (FilterExec reports output_rows) plan = df.execution_plan() found_metrics = False for _, ms in plan.collect_metrics(): From e1d0c81a78e2abdeb19f9f13efc9a4c4f11f3382 Mon Sep 17 00:00:00 2001 From: ShreyeshArangath Date: Thu, 19 Mar 2026 13:33:09 -0700 Subject: [PATCH 3/4] first round of reviews --- .../dataframe/execution-metrics.rst | 164 ++++++++++++++++++ 
docs/source/user-guide/dataframe/index.rst | 9 + python/datafusion/plan.py | 77 ++++++-- python/tests/test_plans.py | 52 +++--- src/dataframe.rs | 38 ++-- src/metrics.rs | 20 +-- 6 files changed, 296 insertions(+), 64 deletions(-) create mode 100644 docs/source/user-guide/dataframe/execution-metrics.rst diff --git a/docs/source/user-guide/dataframe/execution-metrics.rst b/docs/source/user-guide/dataframe/execution-metrics.rst new file mode 100644 index 000000000..42262b036 --- /dev/null +++ b/docs/source/user-guide/dataframe/execution-metrics.rst @@ -0,0 +1,164 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +.. _execution_metrics: + +Execution Metrics +================= + +Overview +-------- + +When DataFusion executes a query it compiles the logical plan into a tree of +*physical plan operators* (e.g. ``FilterExec``, ``ProjectionExec``, +``HashAggregateExec``). Each operator can record runtime statistics while it +runs. These statistics are called **execution metrics**. 
+ +Typical metrics include: + +- **output_rows** – number of rows produced by the operator +- **elapsed_compute** – total CPU time (nanoseconds) spent inside the operator +- **spill_count** – number of times the operator spilled data to disk +- **spilled_bytes** – total bytes written to disk during spills +- **spilled_rows** – total rows written to disk during spills + +Metrics are collected *per-partition*: DataFusion may execute each operator +in parallel across several partitions. The convenience properties on +:py:class:`~datafusion.MetricsSet` (e.g. ``output_rows``, ``elapsed_compute``) +automatically sum the named metric across **all** partitions, giving a single +aggregate value for the operator as a whole. You can also access the raw +per-partition :py:class:`~datafusion.Metric` objects via +:py:meth:`~datafusion.MetricsSet.metrics`. + +When Are Metrics Available? +--------------------------- + +Metrics are populated only **after** the DataFrame has been executed. +Execution is triggered by any of the terminal operations: + +- :py:meth:`~datafusion.DataFrame.collect` +- :py:meth:`~datafusion.DataFrame.collect_partitioned` +- :py:meth:`~datafusion.DataFrame.execute_stream` +- :py:meth:`~datafusion.DataFrame.execute_stream_partitioned` + +Calling :py:meth:`~datafusion.ExecutionPlan.collect_metrics` before execution +will return entries with empty (or ``None``) metric sets because the operators +have not run yet. + +Reading the Physical Plan Tree +-------------------------------- + +:py:meth:`~datafusion.DataFrame.execution_plan` returns the root +:py:class:`~datafusion.ExecutionPlan` node of the physical plan tree. The tree +mirrors the operator pipeline: the root is typically a projection or +coalescing node; its children are filters, aggregates, scans, etc. + +The ``operator_name`` string returned by +:py:meth:`~datafusion.ExecutionPlan.collect_metrics` is the *display* name of +the node, for example ``"FilterExec: column1@0 > 1"``. 
This is the same string +you would see when calling ``plan.display()``. + +Available Metrics +----------------- + +The following metrics are directly accessible as properties on +:py:class:`~datafusion.MetricsSet`: + +.. list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Property + - Description + * - ``output_rows`` + - Number of rows emitted by the operator (summed across partitions). + * - ``elapsed_compute`` + - CPU time in nanoseconds spent inside the operator's execute loop + (summed across partitions). + * - ``spill_count`` + - Number of spill-to-disk events due to memory pressure (summed across + partitions). + * - ``spilled_bytes`` + - Total bytes written to disk during spills (summed across partitions). + * - ``spilled_rows`` + - Total rows written to disk during spills (summed across partitions). + +Any metric not listed above can be accessed via +:py:meth:`~datafusion.MetricsSet.sum_by_name`, or by iterating over the raw +:py:class:`~datafusion.Metric` objects returned by +:py:meth:`~datafusion.MetricsSet.metrics`. + +Labels +------ + +A :py:class:`~datafusion.Metric` may carry *labels*: key/value pairs that +provide additional context. For example, some operators tag their output +metrics with an ``output_type`` label to distinguish between intermediate and +final output: + +.. code-block:: python + + for metric in metrics_set.metrics(): + print(metric.name, metric.labels()) + # output_rows {'output_type': 'final'} + +Labels are operator-specific; most metrics have no labels. + +End-to-End Example +------------------ + +.. 
code-block:: python + + from datafusion import SessionContext + + ctx = SessionContext() + ctx.sql("CREATE TABLE sales AS VALUES (1, 100), (2, 200), (3, 50)") + + df = ctx.sql("SELECT * FROM sales WHERE column1 > 1") + + # Execute the query — this populates the metrics + results = df.collect() + + # Retrieve the physical plan with metrics + plan = df.execution_plan() + + # Walk every operator and print its metrics + for operator_name, ms in plan.collect_metrics(): + if ms.output_rows is not None: + print(f"{operator_name}") + print(f" output_rows = {ms.output_rows}") + print(f" elapsed_compute = {ms.elapsed_compute} ns") + + # Access raw per-partition metrics + for operator_name, ms in plan.collect_metrics(): + for metric in ms.metrics(): + print( + f" partition={metric.partition} " + f"{metric.name}={metric.value} " + f"labels={metric.labels()}" + ) + +API Reference +------------- + +- :py:class:`datafusion.ExecutionPlan` — physical plan node +- :py:meth:`datafusion.ExecutionPlan.collect_metrics` — walk the tree and + return ``(operator_name, MetricsSet)`` pairs +- :py:meth:`datafusion.ExecutionPlan.metrics` — return the + :py:class:`~datafusion.MetricsSet` for a single node +- :py:class:`datafusion.MetricsSet` — aggregated metrics for one operator +- :py:class:`datafusion.Metric` — a single per-partition metric value diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index 510bcbc68..8475a7bd7 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -365,7 +365,16 @@ DataFusion provides many built-in functions for data manipulation: For a complete list of available functions, see the :py:mod:`datafusion.functions` module documentation. +Execution Metrics +----------------- + +After executing a DataFrame (via ``collect()``, ``execute_stream()``, etc.), +DataFusion populates per-operator runtime statistics such as row counts and +compute time. 
See :doc:`execution-metrics` for a full explanation and +worked example. + .. toctree:: :maxdepth: 1 rendering + execution-metrics diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index d46ff1a00..bac437835 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -162,9 +162,29 @@ def metrics(self) -> MetricsSet | None: return MetricsSet(raw) def collect_metrics(self) -> list[tuple[str, MetricsSet]]: - """Walk the plan tree and collect metrics from all operators. - - Returns a list of (operator_name, MetricsSet) tuples. + """Return runtime statistics for each step of the query execution. + + DataFusion executes a query as a pipeline of operators — for example a + data source scan, followed by a filter, followed by a projection. After + the DataFrame has been executed (via + :py:meth:`~datafusion.DataFrame.collect`, + :py:meth:`~datafusion.DataFrame.execute_stream`, etc.), each operator + records statistics such as how many rows it produced and how much CPU + time it consumed. + + Each entry in the returned list corresponds to one operator that + recorded metrics. The first element of the tuple is the operator's + description string — the same text shown by + :py:meth:`display_indent` — which identifies both the operator type + and its key parameters, for example ``"FilterExec: column1@0 > 1"`` + or ``"DataSourceExec: partitions=1"``. + + Returns: + A list of ``(description, MetricsSet)`` tuples ordered from the + outermost operator (top of the execution tree) down to the + data-source leaves. Only operators that recorded at least one + metric are included. Returns an empty list if called before the + DataFrame has been executed. """ result: list[tuple[str, MetricsSet]] = [] @@ -182,8 +202,11 @@ def _walk(node: ExecutionPlan) -> None: class MetricsSet: """A set of metrics for a single execution plan operator. - Provides both individual metric access and convenience aggregations - across partitions. 
+ A physical plan operator runs independently across one or more partitions. + :py:meth:`metrics` returns the raw per-partition :py:class:`Metric` objects. + The convenience properties (:py:attr:`output_rows`, :py:attr:`elapsed_compute`, + etc.) automatically sum the named metric across *all* partitions, giving a + single aggregate value for the operator as a whole. """ def __init__(self, raw: df_internal.MetricsSet) -> None: @@ -201,12 +224,20 @@ def output_rows(self) -> int | None: @property def elapsed_compute(self) -> int | None: - """Sum of elapsed_compute across all partitions, in nanoseconds.""" + """Total CPU time (in nanoseconds) spent inside this operator's execute loop. + + Summed across all partitions. Returns ``None`` if no ``elapsed_compute`` + metric was recorded. + """ return self._raw.elapsed_compute() @property def spill_count(self) -> int | None: - """Sum of spill_count across all partitions.""" + """Number of times this operator spilled data to disk due to memory pressure. + + This is a count of spill events, not a byte count. Summed across all + partitions. Returns ``None`` if no ``spill_count`` metric was recorded. + """ return self._raw.spill_count() @property @@ -220,7 +251,14 @@ def spilled_rows(self) -> int | None: return self._raw.spilled_rows() def sum_by_name(self, name: str) -> int | None: - """Return the sum of metrics matching the given name.""" + """Sum the named metric across all partitions. + + Useful for accessing any metric not exposed as a first-class property. + Returns ``None`` if no metric with the given name was recorded. + + Args: + name: The metric name, e.g. ``"output_rows"`` or ``"elapsed_compute"``. + """ return self._raw.sum_by_name(name) def __repr__(self) -> str: @@ -242,16 +280,33 @@ def name(self) -> str: @property def value(self) -> int | None: - """The numeric value of this metric, or None for non-numeric types.""" + """The numeric value of this metric, or ``None`` when not representable. 
+ + ``None`` is returned for metric types whose value has not yet been set + (e.g. ``StartTimestamp`` / ``EndTimestamp`` before the operator runs) + and for any metric variant whose value cannot be expressed as an integer. + Timestamp metrics, when available, are returned as nanoseconds since the + Unix epoch. + """ return self._raw.value @property def partition(self) -> int | None: - """The partition this metric applies to, or None if global.""" + """The 0-based partition index this metric applies to. + + Returns ``None`` for metrics that are not partition-specific (i.e. they + apply globally across all partitions of the operator). + """ return self._raw.partition def labels(self) -> dict[str, str]: - """Return the labels associated with this metric.""" + """Return the labels associated with this metric. + + Labels provide additional context for a metric. For example:: + + >>> metric.labels() + {'output_type': 'final'} + """ return self._raw.labels() def __repr__(self) -> str: diff --git a/python/tests/test_plans.py b/python/tests/test_plans.py index d3525b08c..926cb3a40 100644 --- a/python/tests/test_plans.py +++ b/python/tests/test_plans.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-import pytest from datafusion import ( ExecutionPlan, LogicalPlan, @@ -57,13 +56,14 @@ def test_metrics_tree_walk() -> None: results = plan.collect_metrics() assert len(results) >= 1 - found_metrics = False + output_rows_values = [] for name, ms in results: assert isinstance(name, str) assert isinstance(ms, MetricsSet) - if ms.output_rows is not None and ms.output_rows > 0: - found_metrics = True - assert found_metrics + if ms.output_rows is not None: + output_rows_values.append(ms.output_rows) + # The filter passes rows where column1 > 1, so exactly 2 rows from (1,'a'),(2,'b'),(3,'c') + assert 2 in output_rows_values def test_metric_properties() -> None: @@ -73,20 +73,22 @@ def test_metric_properties() -> None: df.collect() plan = df.execution_plan() + found_any_metric = False for _, ms in plan.collect_metrics(): r = repr(ms) assert isinstance(r, str) for metric in ms.metrics(): + found_any_metric = True assert isinstance(metric, Metric) assert isinstance(metric.name, str) assert len(metric.name) > 0 assert metric.partition is None or isinstance(metric.partition, int) + assert metric.value is None or isinstance(metric.value, int) assert isinstance(metric.labels(), dict) mr = repr(metric) assert isinstance(mr, str) assert len(mr) > 0 - return - pytest.skip("No metrics found") + assert found_any_metric, "Expected at least one metric after execution" def test_no_metrics_before_execution() -> None: @@ -95,7 +97,8 @@ def test_no_metrics_before_execution() -> None: df = ctx.sql("SELECT * FROM t") plan = df.execution_plan() ms = plan.metrics() - assert ms is None or ms.output_rows is None or ms.output_rows == 0 + # Before execution, the root plan node has no MetricsSet + assert ms is None def test_collect_partitioned_metrics() -> None: @@ -106,11 +109,12 @@ def test_collect_partitioned_metrics() -> None: df.collect_partitioned() plan = df.execution_plan() - found_metrics = False - for _, ms in plan.collect_metrics(): - if ms.output_rows is not None and ms.output_rows > 
0: - found_metrics = True - assert found_metrics + output_rows_values = [ + ms.output_rows + for _, ms in plan.collect_metrics() + if ms.output_rows is not None + ] + assert 2 in output_rows_values def test_execute_stream_metrics() -> None: @@ -122,11 +126,12 @@ def test_execute_stream_metrics() -> None: pass plan = df.execution_plan() - found_metrics = False - for _, ms in plan.collect_metrics(): - if ms.output_rows is not None and ms.output_rows > 0: - found_metrics = True - assert found_metrics + output_rows_values = [ + ms.output_rows + for _, ms in plan.collect_metrics() + if ms.output_rows is not None + ] + assert 2 in output_rows_values def test_execute_stream_partitioned_metrics() -> None: @@ -139,8 +144,9 @@ def test_execute_stream_partitioned_metrics() -> None: pass plan = df.execution_plan() - found_metrics = False - for _, ms in plan.collect_metrics(): - if ms.output_rows is not None and ms.output_rows > 0: - found_metrics = True - assert found_metrics + output_rows_values = [ + ms.output_rows + for _, ms in plan.collect_metrics() + if ms.output_rows is not None + ] + assert 2 in output_rows_values diff --git a/src/dataframe.rs b/src/dataframe.rs index 3f8f3c94f..3d43809de 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -36,6 +36,7 @@ use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, Table use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; +use datafusion::execution::context::TaskContext; use datafusion::logical_expr::SortExpr; use datafusion::logical_expr::dml::InsertOp; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; @@ -379,6 +380,20 @@ impl PyDataFrame { Ok(html_str) } + /// Create the physical plan, cache it in `last_plan`, and return the plan together + /// with a task context. 
Centralises the repeated three-line pattern that appears in
+    /// `collect`, `collect_partitioned`, `execute_stream`, and `execute_stream_partitioned`.
+    fn create_and_cache_plan(
+        &self,
+        py: Python,
+    ) -> PyDataFusionResult<(Arc<dyn ExecutionPlan>, Arc<TaskContext>)> {
+        let df = self.df.as_ref().clone();
+        let plan = wait_for_future(py, df.create_physical_plan())??;
+        *self.last_plan.lock() = Some(Arc::clone(&plan));
+        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        Ok((plan, task_ctx))
+    }
+
     async fn collect_column_inner(&self, column: &str) -> Result {
         let batches = self
             .df
@@ -637,11 +652,7 @@ impl PyDataFrame {
     /// Unless some order is specified in the plan, there is no
     /// guarantee of the order of the result.
     fn collect<'py>(&self, py: Python<'py>) -> PyResult>> {
-        let df = self.df.as_ref().clone();
-        let plan = wait_for_future(py, df.create_physical_plan())?
-            .map_err(PyDataFusionError::from)?;
-        *self.last_plan.lock() = Some(Arc::clone(&plan));
-        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let (plan, task_ctx) = self.create_and_cache_plan(py)?;
         let batches = wait_for_future(py, df_collect(plan, task_ctx))?
             .map_err(PyDataFusionError::from)?;
         // cannot use PyResult> return type due to
@@ -658,11 +669,7 @@ impl PyDataFrame {
     /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
     /// maintaining the input partitioning.
     fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult>>> {
-        let df = self.df.as_ref().clone();
-        let plan = wait_for_future(py, df.create_physical_plan())?
-            .map_err(PyDataFusionError::from)?;
-        *self.last_plan.lock() = Some(Arc::clone(&plan));
-        let task_ctx = Arc::new(self.df.as_ref().task_ctx());
+        let (plan, task_ctx) = self.create_and_cache_plan(py)?;
         let batches = wait_for_future(py, df_collect_partitioned(plan, task_ctx))?
.map_err(PyDataFusionError::from)?; @@ -1153,20 +1160,13 @@ impl PyDataFrame { } fn execute_stream(&self, py: Python) -> PyDataFusionResult { - let df = self.df.as_ref().clone(); - let plan = wait_for_future(py, df.create_physical_plan())??; - *self.last_plan.lock() = Some(Arc::clone(&plan)); - let task_ctx = Arc::new(self.df.as_ref().task_ctx()); + let (plan, task_ctx) = self.create_and_cache_plan(py)?; let stream = spawn_future(py, async move { df_execute_stream(plan, task_ctx) })?; Ok(PyRecordBatchStream::new(stream)) } fn execute_stream_partitioned(&self, py: Python) -> PyResult> { - let df = self.df.as_ref().clone(); - let plan = wait_for_future(py, df.create_physical_plan())? - .map_err(PyDataFusionError::from)?; - *self.last_plan.lock() = Some(Arc::clone(&plan)); - let task_ctx = Arc::new(self.df.as_ref().task_ctx()); + let (plan, task_ctx) = self.create_and_cache_plan(py)?; let streams = spawn_future(py, async move { df_execute_stream_partitioned(plan, task_ctx) })?; diff --git a/src/metrics.rs b/src/metrics.rs index e333ea791..8cd531a88 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -35,7 +35,6 @@ impl PyMetricsSet { #[pymethods] impl PyMetricsSet { - /// Returns all individual metrics in this set. fn metrics(&self) -> Vec { self.metrics .iter() @@ -43,32 +42,26 @@ impl PyMetricsSet { .collect() } - /// Returns the sum of all `output_rows` metrics, or None if not present. fn output_rows(&self) -> Option { self.metrics.output_rows() } - /// Returns the sum of all `elapsed_compute` metrics in nanoseconds, or None if not present. fn elapsed_compute(&self) -> Option { self.metrics.elapsed_compute() } - /// Returns the sum of all `spill_count` metrics, or None if not present. fn spill_count(&self) -> Option { self.metrics.spill_count() } - /// Returns the sum of all `spilled_bytes` metrics, or None if not present. fn spilled_bytes(&self) -> Option { self.metrics.spilled_bytes() } - /// Returns the sum of all `spilled_rows` metrics, or None if not present. 
fn spilled_rows(&self) -> Option { self.metrics.spilled_rows() } - /// Returns the sum of metrics matching the given name. fn sum_by_name(&self, name: &str) -> Option { self.metrics.sum_by_name(name).map(|v| v.as_usize()) } @@ -92,13 +85,20 @@ impl PyMetric { #[pymethods] impl PyMetric { - /// Returns the name of this metric. #[getter] fn name(&self) -> String { self.metric.value().name().to_string() } - /// Returns the numeric value of this metric, or None for non-numeric types. + /// Returns the numeric value of this metric as a `usize`, or `None` when the + /// value is not representable as an integer. + /// + /// # Note + /// `StartTimestamp` and `EndTimestamp` metrics are returned as nanoseconds + /// since the Unix epoch (via `timestamp_nanos_opt`), which may overflow + /// a `usize` on 32-bit platforms or return `None` if the timestamp is out + /// of range. Non-numeric metric variants (unrecognised future variants) + /// also return `None`. #[getter] fn value(&self) -> Option { match self.metric.value() { @@ -122,13 +122,11 @@ impl PyMetric { } } - /// Returns the partition this metric is for, or None if it applies to all partitions. #[getter] fn partition(&self) -> Option { self.metric.partition() } - /// Returns the labels associated with this metric as a dict. 
fn labels(&self) -> HashMap { self.metric .labels() From 7200857753b5871a2faf49a4623fbccd0097cc75 Mon Sep 17 00:00:00 2001 From: ShreyeshArangath Date: Thu, 19 Mar 2026 13:39:42 -0700 Subject: [PATCH 4/4] plan caching --- python/datafusion/plan.py | 7 +++++++ python/tests/test_plans.py | 39 ++++++++++++++++++++++++++++++++++++++ src/dataframe.rs | 16 +++++++++++++--- src/metrics.rs | 29 ++++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 3 deletions(-) diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index bac437835..b616b4b8f 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -19,6 +19,8 @@ from __future__ import annotations +import datetime + from typing import TYPE_CHECKING, Any import datafusion._internal as df_internal @@ -290,6 +292,11 @@ def value(self) -> int | None: """ return self._raw.value + @property + def value_as_datetime(self) -> datetime.datetime | None: + """The value as a UTC datetime for timestamp metrics, or ``None``.""" + return self._raw.value_as_datetime() + @property def partition(self) -> int | None: """The 0-based partition index this metric applies to. diff --git a/python/tests/test_plans.py b/python/tests/test_plans.py index 926cb3a40..f006ae3cc 100644 --- a/python/tests/test_plans.py +++ b/python/tests/test_plans.py @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. 
+import datetime + +import pytest + from datafusion import ( ExecutionPlan, LogicalPlan, @@ -150,3 +154,38 @@ def test_execute_stream_partitioned_metrics() -> None: if ms.output_rows is not None ] assert 2 in output_rows_values + + +def test_value_as_datetime() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + df.collect() + plan = df.execution_plan() + + for _, ms in plan.collect_metrics(): + for metric in ms.metrics(): + if metric.name in ("start_timestamp", "end_timestamp"): + dt = metric.value_as_datetime + assert dt is None or isinstance(dt, datetime.datetime) + if dt is not None: + assert dt.tzinfo is not None + else: + assert metric.value_as_datetime is None + + +def test_collect_twice_reuses_plan() -> None: + ctx = SessionContext() + ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + df = ctx.sql("SELECT * FROM t WHERE column1 > 1") + + df.collect() + df.collect() + + plan = df.execution_plan() + output_rows_values = [ + ms.output_rows + for _, ms in plan.collect_metrics() + if ms.output_rows is not None + ] + assert len(output_rows_values) > 0 diff --git a/src/dataframe.rs b/src/dataframe.rs index 3d43809de..eed7e11a7 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -387,9 +387,19 @@ impl PyDataFrame { &self, py: Python, ) -> PyDataFusionResult<(Arc, Arc)> { - let df = self.df.as_ref().clone(); - let plan = wait_for_future(py, df.create_physical_plan())??; - *self.last_plan.lock() = Some(Arc::clone(&plan)); + let plan = { + let cached = self.last_plan.lock(); + cached.as_ref().map(Arc::clone) + }; + let plan = match plan { + Some(p) => p, + None => { + let df = self.df.as_ref().clone(); + let new_plan = wait_for_future(py, df.create_physical_plan())??; + *self.last_plan.lock() = Some(Arc::clone(&new_plan)); + new_plan + } + }; let task_ctx = Arc::new(self.df.as_ref().task_ctx()); Ok((plan, task_ctx)) } diff --git a/src/metrics.rs 
b/src/metrics.rs
index 8cd531a88..e7ab856ff 100644
--- a/src/metrics.rs
+++ b/src/metrics.rs
@@ -122,6 +122,35 @@ impl PyMetric {
         }
     }
 
+    /// Returns the value as a Python `datetime` for `StartTimestamp` / `EndTimestamp`
+    /// metrics, or `None` for all other metric types.
+    fn value_as_datetime<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
+        match self.metric.value() {
+            MetricValue::StartTimestamp(ts) | MetricValue::EndTimestamp(ts) => {
+                match ts.value() {
+                    Some(dt) => {
+                        let nanos = dt.timestamp_nanos_opt()
+                            .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyOverflowError, _>(
+                                "timestamp out of range"
+                            ))?;
+                        let datetime_mod = py.import("datetime")?;
+                        let datetime_cls = datetime_mod.getattr("datetime")?;
+                        let tz_utc = datetime_mod.getattr("timezone")?.getattr("utc")?;
+                        let secs = nanos / 1_000_000_000;
+                        let micros = (nanos % 1_000_000_000) / 1_000;
+                        let result = datetime_cls.call_method1(
+                            "fromtimestamp",
+                            (secs as f64 + micros as f64 / 1_000_000.0, tz_utc),
+                        )?;
+                        Ok(Some(result))
+                    }
+                    None => Ok(None),
+                }
+            }
+            _ => Ok(None),
+        }
+    }
+
     #[getter]
     fn partition(&self) -> Option<usize> {
         self.metric.partition()