diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py
index 2956afad02..e03462bf78 100644
--- a/pytensor/link/jax/dispatch/tensor_basic.py
+++ b/pytensor/link/jax/dispatch/tensor_basic.py
@@ -87,14 +87,7 @@ def jax_funcify_Join(op, **kwargs):
     def join(axis, *tensors):
         # tensors could also be tuples, and in this case they don't have a ndim
         tensors = [jnp.asarray(tensor) for tensor in tensors]
-        view = op.view
-        if (view != -1) and all(
-            tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :]
-        ):
-            return tensors[view]
-
-        else:
-            return jnp.concatenate(tensors, axis=axis)
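+        # jnp.concatenate normalizes negative axes and validates the input shapes itself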
+        return jnp.concatenate(tensors, axis=axis)
 
     return join
 
diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py
index 7daa625794..7749514e03 100644
--- a/pytensor/link/numba/dispatch/tensor_basic.py
+++ b/pytensor/link/numba/dispatch/tensor_basic.py
@@ -117,17 +117,9 @@ def arange(start, stop, step):
 
 @numba_funcify.register(Join)
 def numba_funcify_Join(op, **kwargs):
-    view = op.view
-
-    if view != -1:
-        # TODO: Where (and why) is this `Join.view` even being used?  From a
-        # quick search, the answer appears to be "nowhere", so we should
-        # probably just remove it.
-        raise NotImplementedError("The `view` parameter to `Join` is not supported")
-
     @numba_basic.numba_njit
     def join(axis, *tensors):
-        return np.concatenate(tensors, numba_basic.to_scalar(axis))
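+        # axis is received as a 0-d array here; .item() extracts the scalar expected by np.concatenate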
+        return np.concatenate(tensors, axis.item())
 
     return join
 
diff --git a/pytensor/scan/checkpoints.py b/pytensor/scan/checkpoints.py
index 8c237267d5..d974e8257e 100644
--- a/pytensor/scan/checkpoints.py
+++ b/pytensor/scan/checkpoints.py
@@ -1,6 +1,5 @@
 import pytensor.tensor.basic as ptb
 from pytensor.scan.basic import scan
-from pytensor.tensor.basic import Join
 from pytensor.tensor.math import ceil, eq, neq
 from pytensor.tensor.subtensor import set_subtensor
 
@@ -127,14 +126,12 @@ def scan_checkpoints(
 
     # Pad the sequences if needed
     if padding:
-        # Since padding could be an empty tensor, Join returns a view of s.
-        join = Join(view=0)
         for i, s in enumerate(sequences):
             overshoots_by = s.shape[0] % save_every_N
             overshoots = neq(overshoots_by, 0)
             n = (save_every_N - overshoots_by) * overshoots
             z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype)
-            sequences[i] = join(0, s, z)
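+            # z has zero rows when the sequence length is already a multiple of save_every_N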
+            sequences[i] = ptb.join(0, s, z)
 
     # Establish the input variables of the outer scan
     o_sequences = [
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 5d6c059c53..9117a0d99d 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -2412,37 +2412,17 @@ class Join(COp):
     The axis has to be an index into the shape
     >>> pt.join(2, x, y, z)
     Traceback (most recent call last):
-    ValueError: Axis value 2 is out of range for the given input dimensions
+    numpy.exceptions.AxisError: axis 2 is out of bounds for array of dimension 2
 
     Joined tensors must have the same rank
     >>> pt.join(0, x, u)
     Traceback (most recent call last):
-    TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 1].
+    TypeError: Only tensors with the same number of dimensions can be joined
 
     """
 
     check_input = False
-    __props__ = ("view",)
-
-    def __init__(self, view=-1):
-        self.view = view
-        if view != -1:
-            # since the first input is always the axis, the tensors
-            # start from index 1.
-            self.view_map = {0: [1 + view]}
-
-    def __str__(self):
-        if self.view == -1:
-            return self.__class__.__name__
-        else:
-            classname = self.__class__.__name__
-            args = ", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__)
-            return f"{classname}{{{args}}}"
-
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-        if not hasattr(self, "view"):
-            self.view = -1
+    __props__ = ()
 
     def make_node(self, axis, *tensors):
         """
@@ -2459,74 +2439,60 @@ def make_node(self, axis, *tensors):
         if not tensors:
             raise ValueError("Cannot join an empty list of tensors")
 
+        axis = as_tensor_variable(axis)
+        if axis.type.dtype not in int_dtypes:
+            raise TypeError(f"Axis {axis} must be an integer type.")
+        if axis.type.ndim > 0:
+            raise TypeError(f"Axis {axis} must be 0-d.")
+
         tensors = [as_tensor_variable(x) for x in tensors]
-        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
 
-        if not builtins.all(targs.type.ndim for targs in tensors):
+        if not builtins.all(targs.type.ndim > 0 for targs in tensors):
             raise TypeError(
                 "Join cannot handle arguments of dimension 0."
-                " Use `stack` to join scalar values."
+                " Use `stack` to join scalar values and/or increase rank of scalars."
             )
 
         if len(tensors) == 1:
             out_shape = tensors[0].type.shape
         else:
-            # When the axis is fixed, a dimension should be
-            # broadcastable if at least one of the inputs is
-            # broadcastable on that dimension (see justification below),
-            # except for the axis dimension.
-            # Initialize bcastable all false, and then fill in some trues with
-            # the loops.
-
-            if not isinstance(axis, int):
-                try:
-                    axis = int(get_scalar_constant_value(axis))
-                except NotScalarConstantError:
-                    pass
-
             ndim = tensors[0].type.ndim
-            if isinstance(axis, int):
-                # Basically, broadcastable -> length 1, but the
-                # converse does not hold. So we permit e.g. T/F/T
-                # joins, and if they fail at runtime they fail, but if
-                # they don't then it means that the argument where
-                # that broadcastable flag was False had length 1 along
-                # this dimension, and therefore this dimension should
-                # be broadcastable for the output.
-
-                if axis < -ndim:
-                    raise IndexError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                if axis < 0:
-                    axis += ndim
-                if axis > ndim - 1:
-                    raise ValueError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                # NOTE: Constant negative axis can no longer be negative at this point.
-
-                in_shapes = [x.type.shape for x in tensors]
-                in_ndims = [len(s) for s in in_shapes]
-                if set(in_ndims) != {ndim}:
-                    raise TypeError(
-                        "Only tensors with the same number of dimensions can be joined."
-                        f" Input ndims were: {in_ndims}."
-                    )
+
+            if not builtins.all(x.ndim == ndim for x in tensors):
+                raise TypeError(
+                    "Only tensors with the same number of dimensions can be joined"
+                )
+
+            try:
+                static_axis = int(get_scalar_constant_value(axis))
+            except NotScalarConstantError:
+                static_axis = None
+
+            if static_axis is None:
+                # When the axis isn't static, we can't conclude anything about the output dimensions
+                # (unless there are some degenerate zero-length arrays, which can be removed during rewrites).
+                # We could also raise an error if any dimensions were pairwise inconsistent across all the axes,
+                # as the join would then be invalid no matter the axis.
+                # However, a dynamic axis is so rare that it is not worth the trouble.
+                out_shape = [None] * ndim
+
+            else:  # We know the axis statically
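+                # normalize_axis_index maps negative axes into [0, ndim) and raises AxisError when out of bounds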
+                static_axis = normalize_axis_index(static_axis, ndim)
+                static_shapes = [x.type.shape for x in tensors]
 
                 # Determine output shapes from a matrix of input shapes
-                in_shapes = np.array(in_shapes)
+                static_shapes = np.array(static_shapes)
                 out_shape = [None] * ndim
                 for d in range(ndim):
-                    ins = in_shapes[:, d]
-                    if d == axis:
-                        # Any unknown size along the axis means we can't sum
+                    ins = static_shapes[:, d]
+                    if d == static_axis:
+                        # Any unknown size along the axis means we can't infer it
                         if None in ins:
                             out_shape[d] = None
                         else:
                             out_shape[d] = sum(ins)
                     else:
-                        inset = set(in_shapes[:, d])
+                        inset = set(static_shapes[:, d])
                         # Other dims must match exactly,
                         # or if a mix of None and ? the output will be ?
                         # otherwise the input shapes are incompatible.
@@ -2536,100 +2502,141 @@ def make_node(self, axis, *tensors):
                             (out_shape[d],) = inset - {None}
                         else:
                             raise ValueError(
-                                f"all input array dimensions other than the specified `axis` ({axis})"
+                                f"all input array dimensions other than the specified `axis` ({static_axis})"
                                 " must match exactly, or be unknown (None),"
                                 f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
                             )
-            else:
-                # When the axis may vary, no dimension can be guaranteed to be
-                # broadcastable.
-                out_shape = [None] * tensors[0].type.ndim
 
-            if not builtins.all(x.ndim == len(out_shape) for x in tensors):
-                raise TypeError(
-                    "Only tensors with the same number of dimensions can be joined"
-                )
-
-        inputs = [as_tensor_variable(axis), *tensors]
+        inputs = [axis, *tensors]
+        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
+        return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])
 
-        if inputs[0].type.dtype not in int_dtypes:
-            raise TypeError(f"Axis value {inputs[0]} must be an integer type")
+    def perform(self, node, inputs, output_storage):
+        axis, *arrays = inputs
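+        # Let np.concatenate validate the axis and shapes; dtype matches the upcast output type chosen in make_node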
+        output_storage[0][0] = np.concatenate(
+            arrays, axis=axis, dtype=node.outputs[0].type.dtype
+        )
 
-        return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])
+    def c_code_cache_version(self):
+        return (7,)
 
-    def perform(self, node, axis_and_tensors, out_):
-        (out,) = out_
-        view = self.view
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
-        # we check these tensors for being empty.
-        if (view != -1) and all(
-            tensor.shape[axis] == 0 for tensor in tens[0:view] + tens[view + 1 :]
-        ):
-            out[0] = tens[view]
+    def c_code(self, node, name, inputs, outputs, sub):
+        axis, *arrays = inputs
+        [out] = outputs
+        n = len(arrays)
+        ndim = node.outputs[0].type.ndim
+        fail = sub["fail"]
 
+        # Most of the time the axis is constant, so inline it
+        # This is safe to do because the hash of the c_code includes the constant signature
+        if isinstance(node.inputs[0], Constant):
+            static_axis = int(node.inputs[0].data)
+            static_axis = normalize_axis_index(static_axis, ndim)
+            axis_def = f"{static_axis};"
+            axis_check = ""
         else:
-            ndim = tens[0].ndim
-            if axis < -ndim:
-                raise IndexError(
-                    f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
-                )
+            axis_dtype = node.inputs[0].type.dtype_specs()[1]
+            axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
+            axis_check = f"""
+                if (axis < 0){{
+                    axis = {ndim} + axis;
+                }}
+                if (axis >= {ndim} || axis < 0) {{
+                    PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
+                    {fail}
+                }}
+            """
 
-            out[0] = np.asarray(
-                np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype
+        copy_arrays_to_tuple = "\n".join(
+            (
+                f"""Py_INCREF({array}); PyTuple_SetItem(arrays_tuple, {i}, (PyObject*){array});"""
+                for i, array in enumerate(arrays)
             )
+        )
 
-    def c_code_cache_version(self):
-        return (5,)
+        code = f"""
+        int axis = {axis_def}
+        PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
+        int out_is_valid = {out} != NULL;
 
-    def c_code(self, node, name, inputs, outputs, sub):
-        axis, tens = inputs[0], inputs[1:]
-        view = self.view
-        non_empty_tensor = tens[view]
-        input_1 = tens[0]
-        l = len(tens)
-        (out,) = outputs
-        fail = sub["fail"]
-        adtype = node.inputs[0].type.dtype_specs()[1]
+        {axis_check}
 
-        copy_to_list = (
-            f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});"""
-            for i, inp in enumerate(tens)
-        )
+        if (out_is_valid) {{
+            // Check if we can reuse output
+            npy_intp join_size = 0;
+            npy_intp out_shape[{ndim}];
+            npy_intp *shape = PyArray_SHAPE(arrays[0]);
 
-        copy_inputs_to_list = "\n".join(copy_to_list)
-        n = len(tens)
+            for (int i = 0; i < {n}; i++) {{
+                if (PyArray_NDIM(arrays[i]) != {ndim}) {{
+                    PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
+                    {fail}
+                }}
 
-        code = f"""
-        int axis = (({adtype} *)PyArray_DATA({axis}))[0];
-        PyObject* list = PyList_New({l});
-        {copy_inputs_to_list}
-        int tensors_lens_sum;
-        if({view} != -1) {{
-            tensors_lens_sum = 0;
-
-            for(int i=0; i < {n}; i++){{
-                tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis);
+                join_size += PyArray_SHAPE(arrays[i])[axis];
+
+                if (i > 0){{
+                    for (int j = 0; j < {ndim}; j++) {{
+                        if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
+                            PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
+                            {fail}
+                        }}
+                    }}
+                }}
+            }}
+
+            memcpy(out_shape, shape, {ndim} * sizeof(npy_intp));
+            out_shape[axis] = join_size;
+
+            for (int i = 0; i < {ndim}; i++) {{
+                out_is_valid &= (PyArray_SHAPE({out})[i] == out_shape[i]);
             }}
-            tensors_lens_sum -= PyArray_DIM({non_empty_tensor}, axis);
         }}
-        if({view} != -1 && tensors_lens_sum == 0) {{
+
+        if (!out_is_valid) {{
+            // Use PyArray_Concatenate
             Py_XDECREF({out});
-            Py_INCREF({non_empty_tensor});
-            {out} = {non_empty_tensor};
-        }}else{{
-            //PyObject* PyArray_Concatenate(PyObject* obj, int axis)
-            int ndim = PyArray_NDIM({input_1});
-            if( axis < -ndim ){{
-                PyErr_Format(PyExc_IndexError,
-                             "Join axis %d out of bounds [0, %d)", axis, ndim);
+            PyObject* arrays_tuple = PyTuple_New({n});
+            {copy_arrays_to_tuple}
+            {out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
+            Py_DECREF(arrays_tuple);
+            if(!{out}){{
                 {fail}
             }}
-            Py_XDECREF({out});
-            {out} = (PyArrayObject *)PyArray_Concatenate(list, axis);
-            Py_DECREF(list);
-            if(!{out}){{
+        }}
+        else {{
+            // Copy the data to the pre-allocated output buffer
+
+            // Create view into output buffer
+            PyArrayObject_fields *view;
+
+            // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
+            Py_INCREF(PyArray_DESCR({out}));
+            view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
+                                                                  PyArray_DESCR({out}),
+                                                                  {ndim},
+                                                                  PyArray_SHAPE(arrays[0]),
+                                                                  PyArray_STRIDES({out}),
+                                                                  PyArray_DATA({out}),
+                                                                  NPY_ARRAY_WRITEABLE,
+                                                                  NULL);
+            if (view == NULL) {{
                 {fail}
             }}
+
+            // Copy data into output buffer
+            for (int i = 0; i < {n}; i++) {{
+                view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
+
+                if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
+                    Py_DECREF(view);
+                    {fail}
+                }}
+
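+                // Advance the view's data pointer past the block that was just copied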
+                view->data += (view->dimensions[axis] * view->strides[axis]);
+            }}
+
+            Py_DECREF(view);
         }}
         """
         return code
@@ -2639,22 +2646,21 @@ def R_op(self, inputs, eval_points):
             return [None]
         return self.make_node(inputs[0], *eval_points[1:]).outputs
 
-    def grad(self, axis_and_tensors, grads):
+    def L_op(self, inputs, outputs, grads):
         """The gradient wrt a join op is a `Split`, used to partition
         the gradient along the `axis` which was used for joining.
         """
-        (gz,) = grads
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
+        [gz] = grads
+        [out] = outputs
+        axis, *tensors = inputs
 
         rval = [grad_undefined(self, 0, axis)]
-
-        dtypes = [as_tensor_variable(x).type.dtype for x in tens]
-        out_dtype = ps.upcast(*dtypes)
+        out_dtype = out.type.dtype
 
         if "float" in out_dtype or "complex" in out_dtype:
             # assume that this is differentiable
-            split = Split(len(tens))
-            split_gz = split(gz, axis, stack([shape(x)[axis] for x in tens]))
+            split_sizes = stack([shape(x)[axis] for x in tensors])
+            split_gz = split(gz, split_sizes, n_splits=len(tensors), axis=axis)
             # If there is only one split, it might not be in a list.
             if not isinstance(split_gz, list):
                 split_gz = [split_gz]
@@ -2667,13 +2673,12 @@ def grad(self, axis_and_tensors, grads):
                 else specify_broadcastable(
                     g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
                 )
-                for t, g in zip(tens, split_gz, strict=True)
+                for t, g in zip(tensors, split_gz, strict=True)
             ]
             rval = rval + split_gz
         else:
-            # the output has integer type, so the gradient through it
-            # is 0
-            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tens]
+            # the output has integer type, so the gradient through it is 0
+            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors]
 
         return rval
 
@@ -2693,7 +2698,8 @@ def infer_shape(self, fgraph, node, ishapes):
         # An axis < -n_dim or >= ndim would be invalid, but this is
         # not checked here. A `CheckAndRaise` `Op` would be a way of
         # addressing that, but it may disrupt optimizations.
-        join_dim = switch(ge(node.inputs[0], 0), node.inputs[0], node.inputs[0] + n_dim)
+        axis = node.inputs[0]
+        join_dim = switch(ge(axis, 0), axis, axis + n_dim)
         out_shapes = []
         for dim in range(n_dim):
             # we have to deal with 2 possible cases in here :
@@ -2716,7 +2722,7 @@ def infer_shape(self, fgraph, node, ishapes):
         return [tuple(out_shapes)]
 
 
-join_ = Join()
+_join = Join()
 pprint.assign(Join, printing.FunctionPrinter(["join"]))
 
 
@@ -2759,7 +2765,7 @@ def join(axis, *tensors_list):
     if len(tensors_list) == 1:
         return tensors_list[0]
     else:
-        return join_(axis, *tensors_list)
+        return _join(axis, *tensors_list)
 
 
 @_vectorize_node.register(Join)
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 59148fae3b..61db37bd27 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -41,6 +41,7 @@
     node_rewriter,
 )
 from pytensor.graph.rewriting.db import RewriteDatabase
+from pytensor.npy_2_compat import normalize_axis_index
 from pytensor.raise_op import Assert, CheckAndRaise, assert_op
 from pytensor.scalar.basic import Second
 from pytensor.tensor.basic import (
@@ -817,52 +818,38 @@ def local_join_1(fgraph, node):
         return [tensors[0]]
 
 
-# TODO: merge in local_useless_join
-@register_infer_shape
 @register_useless
-@register_specialize
 @register_canonicalize
+@register_specialize
 @node_rewriter([Join])
 def local_join_empty(fgraph, node):
     """Join(i, x, y, empty) => Join(i, x, y)
 
     Remove empty inputs to joins. The empty inputs can be anywhere.
-
     """
-    if not isinstance(node.op, Join):
-        return
-    new_inputs = []
+    axis, *tensors = node.inputs
+
     try:
-        join_idx = get_scalar_constant_value(
+        static_axis = get_scalar_constant_value(
             node.inputs[0], only_process_constants=True
         )
     except NotScalarConstantError:
         return
-    for idx in range(1, len(node.inputs)):
-        inp = node.inputs[idx]
-        # We can not use size == 0,, as this can change shape from 3,0
-        # to 2,0.  This trigger DebugMode error. This happen with
-        # stack(...,[]) as this add a dimshuffle on [], that add a
-        # dimensions with shape 1.
-        if isinstance(inp, Constant) and inp.data.shape[join_idx] == 0:
-            continue
-        new_inputs.append(inp)
-    if len(new_inputs) < len(node.inputs) - 1:
-        if len(new_inputs) == 0:
-            # at.join do not work in that case.
-            # constant folding will take care of this case.
-            return
-        ret = join(node.inputs[0], *new_inputs)
-        o = node.outputs[0]
-        if ret.dtype != o.dtype:
-            # Join can upcast some inputs
-            return
 
-        # Copy over stacktrace from previous output (after join op)
-        # to new output, because an error in the new op must be caused
-        # by an error in the old join op.
-        copy_stack_trace(node.outputs, ret)
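+    # Keep only inputs whose static length along the constant join axis is not known to be zero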
+    new_tensors = [tensor for tensor in tensors if tensor.type.shape[static_axis] != 0]
+
+    # If no non-empty tensors remain, the join is useless, but so is any other operation here;
+    # another rewrite will (one day) handle all those cases
+    if 0 < len(new_tensors) < len(tensors):
+        # join returns the single tensor directly when only one input remains, so no need to check for that here
+        ret = join(axis, *new_tensors)
+
+        [old_output] = node.outputs
+
+        if ret.dtype != old_output.dtype:
+            ret = ret.astype(old_output.dtype)
 
+        copy_stack_trace(old_output, ret)
         return [ret]
 
 
@@ -1298,7 +1285,7 @@ def local_join_of_alloc(fgraph, node):
     # Axis can never be lifted
     # Non-axis allocated dimensions can be lifted if they are all broadcastable
     [out] = node.outputs
-    axis = axis.data
+    static_axis = normalize_axis_index(axis.data, tensors[0].type.ndim)
 
     broadcasted_dims = list(
         zip(
@@ -1320,7 +1307,7 @@ def local_join_of_alloc(fgraph, node):
     lifteable_alloc_dims = {
         dim
         for dim in range(out.type.ndim)
-        if dim != axis and all(broadcasted_dims[dim])
+        if dim != static_axis and all(broadcasted_dims[dim])
     }
 
     if not lifteable_alloc_dims:
@@ -1337,13 +1324,13 @@ def local_join_of_alloc(fgraph, node):
         copy_stack_trace(tensor, new_tensor)
         new_tensors.append(new_tensor)
 
-    new_join = node.op(axis, *new_tensors)
+    new_join = node.op(static_axis, *new_tensors)
     copy_stack_trace(node.outputs[0], new_join)
 
     # Reintroduce the lifted dims
     post_join_shape = []
     for i, alloc_dims in enumerate(zip(*alloc_shapes, strict=True)):
-        if i == axis:
+        if i == static_axis:
             # The alloc dim along the axis is the sum of all the pre-join alloc dims
             post_join_shape.append(add(*alloc_dims))
         else:
diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py
index 09963f9d36..625246e340 100644
--- a/tests/link/numba/test_tensor_basic.py
+++ b/tests/link/numba/test_tensor_basic.py
@@ -172,24 +172,6 @@ def test_Join(vals, axis):
     )
 
 
-def test_Join_view():
-    vals, vals_test = zip(
-        *(
-            (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
-            (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
-        ),
-        strict=True,
-    )
-    g = ptb.Join(view=1)(1, *vals)
-
-    with pytest.raises(NotImplementedError):
-        compare_numba_and_py(
-            vals,
-            g,
-            vals_test,
-        )
-
-
 @pytest.mark.parametrize(
     "n_splits, axis, values, sizes",
     [
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index 1730ae46ac..a959efd6d3 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -1248,65 +1248,41 @@ def test_local_join_1():
 
 
 def test_local_join_empty():
-    # test for vector, vector, empty to vector
+    # Vector case
     empty_vec = np.asarray([], dtype=config.floatX)
-    a = vector("a")
-    s = pt.join(0, a, a, empty_vec)
-    f = function([a], s, mode=rewrite_mode)
-    val = f([1])
-    assert np.all(val == [1])
-    e = f.maker.fgraph.toposort()
-    assert len([n for n in e if isinstance(n.op, Join)]) == 1
-    assert all(
-        not isinstance(n.op, Join) or len(n.inputs) == 3
-        for n in e
-        if isinstance(n.op, Join)
+    vec = vector("vec")
+    s = pt.join(0, vec, vec, empty_vec)
+    new_s = rewrite_graph(s)
+    assert equal_computations([new_s], [join(0, vec, vec)])
+    assert new_s.dtype == s.dtype
+
+    # Matrix case
+    empty_mat = np.zeros((2, 0), dtype=config.floatX)
+    empty_sym_mat = matrix("m", shape=(2, 0))
+    mat = matrix("mat", shape=(2, 10))
+    s = join(1, empty_mat, mat, empty_sym_mat, mat, mat)
+    new_s = rewrite_graph(s)
+    assert equal_computations([new_s], [join(1, mat, mat, mat)])
+    assert new_s.dtype == s.dtype
+
+    # Join can be completely removed, but casting and specify_shape are propagated
+    int_mat = matrix("int_mat", dtype=int)
+    s = join(-1, empty_mat, int_mat, empty_sym_mat)
+    new_s = rewrite_graph(s)
+    assert equal_computations(
+        [new_s], [specify_shape(int_mat, (2, None)).astype(s.dtype)]
     )
-    assert f.maker.fgraph.outputs[0].dtype == config.floatX
 
-    # test for matrix join(1,a)
-    empty_mat = np.asarray([[]], dtype=config.floatX)
-    m = matrix("m")
-    s = join(1, empty_mat, m, m, m)
-    f = function([m], s, mode=rewrite_mode)
-    val = f([[1]])
-    assert np.all(val == [[1]])
-    e = f.maker.fgraph.toposort()
-    assert len([n for n in e if isinstance(n.op, Join)]) == 1
-    assert all(
-        not isinstance(n.op, Join) or len(n.inputs) == 4
-        for n in e
-        if isinstance(n.op, Join)
-    )
-    assert f.maker.fgraph.outputs[0].dtype == config.floatX
-    # test for vector, vector, empty to matrix
-    # We can't rewrite this case.
-    s = pt.stack([a, a, empty_vec])
-    f = function([a], s, mode=rewrite_mode)
-    val = f([])
-    assert np.all(val == [1])
-    e = f.maker.fgraph.toposort()
-    assert len([n for n in e if isinstance(n.op, Join)]) == 1
-    assert all(
-        not isinstance(n.op, Join) or len(n.inputs) == 4
-        for n in e
-        if isinstance(n.op, Join)
-    )
-    assert f.maker.fgraph.outputs[0].dtype == config.floatX
-    # test for matrix join(0,a)
-    # We can't rewrite this case.
-    s = join(0, m, np.asarray([[2.0]], dtype=config.floatX), m)
-    f = function([m], s, mode=rewrite_mode)
-    val = f([[1]])
-    assert np.all(val == [[1], [2], [1]])
-    e = f.maker.fgraph.toposort()
-    assert len([n for n in e if isinstance(n.op, Join)]) == 1
-    assert all(
-        not isinstance(n.op, Join) or len(n.inputs) == 4
-        for n in e
-        if isinstance(n.op, Join)
-    )
-    assert f.maker.fgraph.outputs[0].dtype == config.floatX
+    # Dynamic axis, can't apply rewrite
+    axis = scalar("axis", dtype=int)
+    s = join(axis, empty_mat, int_mat, empty_sym_mat)
+    new_s = rewrite_graph(s)
+    assert equal_computations([new_s], [s])
+
+    # Stack introduces an expand_dims in the join, that's a nonzero dim!
+    s = pt.stack([vec, vec, empty_vec])
+    new_s = rewrite_graph(s)
+    assert equal_computations([new_s], [s])
 
 
 def test_local_join_make_vector():
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index e29a47691a..c8c2bb224a 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -117,6 +117,7 @@
     ivector,
     lscalar,
     lvector,
+    matrices,
     matrix,
     row,
     scalar,
@@ -1762,7 +1763,7 @@ def test_join_matrixV_negative_axis(self):
         got = f(-2)
         assert np.allclose(got, want)
 
-        with pytest.raises(IndexError):
+        with pytest.raises(ValueError):
             f(-3)
 
     @pytest.mark.parametrize("py_impl", (False, True))
@@ -1805,7 +1806,7 @@ def test_join_matrixC_negative_axis(self, py_impl):
         got = f()
         assert np.allclose(got, want)
 
-        with pytest.raises(IndexError):
+        with pytest.raises(ValueError):
             join(-3, a, b)
 
         with impl_ctxt:
@@ -2118,28 +2119,6 @@ def test_split_static_shape(self):
         y = Split(2)(x, 0, [s, 5 - s])[0]
         assert y.type.shape == (None,)
 
-    def test_join_inplace(self):
-        # Test join to work inplace.
-        #
-        # This function tests the case when several elements are passed to the
-        # join function but all except one of them are empty. In this case join
-        # should work inplace and the output should be the view of the non-empty
-        # element.
-        s = lscalar()
-        x = vector("x")
-        z = ptb.zeros((s,))
-
-        join = Join(view=0)
-        c = join(0, x, z, z)
-
-        f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
-
-        data = np.array([3, 4, 5], dtype=config.floatX)
-
-        if config.mode not in ["DebugMode", "DEBUG_MODE"]:
-            assert f(data, 0) is data
-        assert np.allclose(f(data, 0), [3, 4, 5])
-
     def test_join_oneInput(self):
         # Test join when only 1 input is given.
         #
@@ -2174,6 +2153,32 @@ def test_split_view(self, linker):
             assert np.allclose(r, expected)
             assert r.base is x_test
 
+    @pytest.mark.parametrize("gc", (True, False), ids=lambda x: f"gc={x}")
+    @pytest.mark.parametrize("memory_layout", ["C-contiguous", "F-contiguous", "Mixed"])
+    @pytest.mark.parametrize("axis", (0, 1), ids=lambda x: f"axis={x}")
+    @pytest.mark.parametrize("ndim", (1, 2), ids=["vector", "matrix"])
+    @config.change_flags(cmodule__warn_no_version=False)
+    def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark):
+        if ndim == 1 and not (memory_layout == "C-contiguous" and axis == 0):
+            pytest.skip("Redundant parametrization")
+        n = 64
+        inputs = vectors("abcdef") if ndim == 1 else matrices("abcdef")
+        out = join(axis, *inputs)
+        fn = pytensor.function(inputs, Out(out, borrow=True), trust_input=True)
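+        # With allow_gc=False, storage is kept between calls, which exercises the output buffer-reuse branch in the C code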
+        fn.vm.allow_gc = gc
+        test_values = [np.zeros((n, n)[:ndim], dtype=inputs[0].dtype) for _ in inputs]
+        if memory_layout == "C-contiguous":
+            pass
+        elif memory_layout == "F-contiguous":
+            test_values = [t.T for t in test_values]
+        elif memory_layout == "Mixed":
+            test_values = [t if i % 2 else t.T for i, t in enumerate(test_values)]
+        else:
+            raise ValueError
+
+        expected_shape = (n * 6, n)[:ndim] if axis == 0 else (n, n * 6)
+        assert fn(*test_values).shape == expected_shape
+        benchmark(fn, *test_values)
+
 
 def test_TensorFromScalar():
     s = ps.constant(56)